Added real-time sync with Sonarr v3 and Radarr v3 by feeding from SignalR feeds. You can now reduce frequency of sync tasks to something like once a day.

pull/1405/head
morpheus65535 4 years ago committed by GitHub
parent 72b6ab3c6a
commit 44c51b2e2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -88,8 +88,8 @@ defaults = {
'full_update_day': '6',
'full_update_hour': '4',
'only_monitored': 'False',
'series_sync': '1',
'episodes_sync': '5',
'series_sync': '60',
'episodes_sync': '60',
'excluded_tags': '[]',
'excluded_series_types': '[]'
},
@ -103,7 +103,7 @@ defaults = {
'full_update_day': '6',
'full_update_hour': '5',
'only_monitored': 'False',
'movies_sync': '5',
'movies_sync': '60',
'excluded_tags': '[]'
},
'proxy': {
@ -208,6 +208,19 @@ str_keys = ['chmod']
empty_values = ['', 'None', 'null', 'undefined', None, []]
# Increase Sonarr and Radarr sync interval since we now use SignalR feed to update in real time
if int(settings.sonarr.series_sync) < 15:
settings.sonarr.series_sync = "60"
if int(settings.sonarr.episodes_sync) < 15:
settings.sonarr.episodes_sync = "60"
if int(settings.radarr.movies_sync) < 15:
settings.radarr.movies_sync = "60"
if os.path.exists(os.path.join(args.config_dir, 'config', 'config.ini')):
with open(os.path.join(args.config_dir, 'config', 'config.ini'), 'w+') as handle:
settings.write(handle)
def get_settings():
result = dict()
sections = settings.sections()
@ -255,6 +268,8 @@ def save_settings(settings_items):
configure_debug = False
configure_captcha = False
update_schedule = False
sonarr_changed = False
radarr_changed = False
update_path_map = False
configure_proxy = False
exclusion_updated = False
@ -309,6 +324,14 @@ def save_settings(settings_items):
'settings-general-upgrade_frequency']:
update_schedule = True
if key in ['settings-general-use_sonarr', 'settings-sonarr-ip', 'settings-sonarr-port',
'settings-sonarr-base_url', 'settings-sonarr-ssl', 'settings-sonarr-apikey']:
sonarr_changed = True
if key in ['settings-general-use_radarr', 'settings-radarr-ip', 'settings-radarr-port',
'settings-radarr-base_url', 'settings-radarr-ssl', 'settings-radarr-apikey']:
radarr_changed = True
if key in ['settings-general-path_mappings', 'settings-general-path_mappings_movie']:
update_path_map = True
@ -388,6 +411,14 @@ def save_settings(settings_items):
from api import scheduler
scheduler.update_configurable_tasks()
if sonarr_changed:
from signalr_client import sonarr_signalr_client
sonarr_signalr_client.restart()
if radarr_changed:
from signalr_client import radarr_signalr_client
radarr_signalr_client.restart()
if update_path_map:
from helper import path_mappings
path_mappings.update()

@ -34,107 +34,28 @@ def sync_episodes():
altered_episodes = []
# Get sonarrId for each series from database
seriesIdList = database.execute("SELECT sonarrSeriesId, title FROM table_shows")
for i, seriesId in enumerate(seriesIdList):
seriesIdList = get_series_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr)
for seriesId in seriesIdList:
# Get episodes data for a series from Sonarr
url_sonarr_api_episode = url_sonarr() + "/api/episode?seriesId=" + str(seriesId['sonarrSeriesId']) + "&apikey=" + apikey_sonarr
try:
r = requests.get(url_sonarr_api_episode, timeout=60, verify=False, headers=headers)
r.raise_for_status()
except requests.exceptions.HTTPError as errh:
logging.exception("BAZARR Error trying to get episodes from Sonarr. Http error.")
return
except requests.exceptions.ConnectionError as errc:
logging.exception("BAZARR Error trying to get episodes from Sonarr. Connection Error.")
return
except requests.exceptions.Timeout as errt:
logging.exception("BAZARR Error trying to get episodes from Sonarr. Timeout Error.")
return
except requests.exceptions.RequestException as err:
logging.exception("BAZARR Error trying to get episodes from Sonarr.")
return
episodes = get_episodes_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr,
series_id=seriesId['sonarrSeriesId'])
if not episodes:
continue
else:
for episode in r.json():
for episode in episodes:
if 'hasFile' in episode:
if episode['hasFile'] is True:
if 'episodeFile' in episode:
if episode['episodeFile']['size'] > 20480:
# Add shows in Sonarr to current shows list
if 'sceneName' in episode['episodeFile']:
sceneName = episode['episodeFile']['sceneName']
else:
sceneName = None
try:
format, resolution = episode['episodeFile']['quality']['quality']['name'].split('-')
except:
format = episode['episodeFile']['quality']['quality']['name']
try:
resolution = str(episode['episodeFile']['quality']['quality']['resolution']) + 'p'
except:
resolution = None
if 'mediaInfo' in episode['episodeFile']:
if 'videoCodec' in episode['episodeFile']['mediaInfo']:
videoCodec = episode['episodeFile']['mediaInfo']['videoCodec']
videoCodec = SonarrFormatVideoCodec(videoCodec)
else: videoCodec = None
if 'audioCodec' in episode['episodeFile']['mediaInfo']:
audioCodec = episode['episodeFile']['mediaInfo']['audioCodec']
audioCodec = SonarrFormatAudioCodec(audioCodec)
else: audioCodec = None
else:
videoCodec = None
audioCodec = None
audio_language = []
if 'language' in episode['episodeFile'] and len(episode['episodeFile']['language']):
item = episode['episodeFile']['language']
if isinstance(item, dict):
if 'name' in item:
audio_language.append(item['name'])
else:
audio_language = database.execute("SELECT audio_language FROM table_shows WHERE "
"sonarrSeriesId=?", (episode['seriesId'],),
only_one=True)['audio_language']
# Add episodes in sonarr to current episode list
current_episodes_sonarr.append(episode['id'])
# Parse episdoe data
if episode['id'] in current_episodes_db_list:
episodes_to_update.append({'sonarrSeriesId': episode['seriesId'],
'sonarrEpisodeId': episode['id'],
'title': episode['title'],
'path': episode['episodeFile']['path'],
'season': episode['seasonNumber'],
'episode': episode['episodeNumber'],
'scene_name': sceneName,
'monitored': str(bool(episode['monitored'])),
'format': format,
'resolution': resolution,
'video_codec': videoCodec,
'audio_codec': audioCodec,
'episode_file_id': episode['episodeFile']['id'],
'audio_language': str(audio_language),
'file_size': episode['episodeFile']['size']})
episodes_to_update.append(episodeParser(episode))
else:
episodes_to_add.append({'sonarrSeriesId': episode['seriesId'],
'sonarrEpisodeId': episode['id'],
'title': episode['title'],
'path': episode['episodeFile']['path'],
'season': episode['seasonNumber'],
'episode': episode['episodeNumber'],
'scene_name': sceneName,
'monitored': str(bool(episode['monitored'])),
'format': format,
'resolution': resolution,
'video_codec': videoCodec,
'audio_codec': audioCodec,
'episode_file_id': episode['episodeFile']['id'],
'audio_language': str(audio_language),
'file_size': episode['episodeFile']['size']})
episodes_to_add.append(episodeParser(episode))
# Remove old episodes from DB
removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr))
@ -176,7 +97,8 @@ def sync_episodes():
added_episode['monitored']])
event_stream(type='episode', payload=added_episode['sonarrEpisodeId'])
else:
logging.debug('BAZARR unable to insert this episode into the database:{}'.format(path_mappings.path_replace(added_episode['path'])))
logging.debug('BAZARR unable to insert this episode into the database:{}'.format(
path_mappings.path_replace(added_episode['path'])))
# Store subtitles for added or modified episodes
for i, altered_episode in enumerate(altered_episodes, 1):
@ -184,44 +106,219 @@ def sync_episodes():
logging.debug('BAZARR All episodes synced from Sonarr into database.')
# Search for desired subtitles if no more than 5 episodes have been added.
if len(altered_episodes) <= 5:
logging.debug("BAZARR No more than 5 episodes were added during this sync then we'll search for subtitles.")
for altered_episode in altered_episodes:
data = database.execute("SELECT table_episodes.sonarrEpisodeId, table_episodes.monitored, table_shows.tags,"
" table_shows.seriesType FROM table_episodes LEFT JOIN table_shows on "
"table_episodes.sonarrSeriesId = table_shows.sonarrSeriesId WHERE "
"sonarrEpisodeId = ?" + get_exclusion_clause('series'), (altered_episode[0],),
only_one=True)
if data:
episode_download_subtitles(data['sonarrEpisodeId'])
else:
logging.debug("BAZARR skipping download for this episode as it is excluded.")
def sync_one_episode(episode_id):
logging.debug('BAZARR syncing this specific episode from Sonarr: {}'.format(episode_id))
# Check if there's a row in database for this episode ID
existing_episode = database.execute('SELECT path FROM table_episodes WHERE sonarrEpisodeId = ?', (episode_id,),
only_one=True)
try:
# Get episode data from sonarr api
episode = None
episode_data = get_episodes_from_sonarr_api(url=url_sonarr(), apikey_sonarr=settings.sonarr.apikey,
episode_id=episode_id)
if not episode_data:
return
else:
episode = episodeParser(episode_data)
except Exception:
logging.debug('BAZARR cannot get episode returned by SignalR feed from Sonarr API.')
return
# Drop useless events
if not episode and not existing_episode:
return
# Remove episode from DB
if not episode and existing_episode:
database.execute("DELETE FROM table_episodes WHERE sonarrEpisodeId=?", (episode_id,))
event_stream(type='episode', action='delete', payload=int(episode_id))
logging.debug('BAZARR deleted this episode from the database:{}'.format(path_mappings.path_replace(
existing_episode['path'])))
return
# Update existing episodes in DB
elif episode and existing_episode:
query = dict_converter.convert(episode)
database.execute('''UPDATE table_episodes SET ''' + query.keys_update + ''' WHERE sonarrEpisodeId = ?''',
query.values + (episode['sonarrEpisodeId'],))
event_stream(type='episode', action='update', payload=int(episode_id))
logging.debug('BAZARR updated this episode into the database:{}'.format(path_mappings.path_replace(
episode['path'])))
# Insert new episodes in DB
elif episode and not existing_episode:
query = dict_converter.convert(episode)
database.execute('''INSERT OR IGNORE INTO table_episodes(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
event_stream(type='episode', action='update', payload=int(episode_id))
logging.debug('BAZARR inserted this episode into the database:{}'.format(path_mappings.path_replace(
episode['path'])))
# Storing existing subtitles
logging.debug('BAZARR storing subtitles for this episode: {}'.format(path_mappings.path_replace(
episode['path'])))
store_subtitles(episode['path'], path_mappings.path_replace(episode['path']))
# Downloading missing subtitles
logging.debug('BAZARR downloading missing subtitles for this episode: {}'.format(path_mappings.path_replace(
episode['path'])))
episode_download_subtitles(episode_id)
def SonarrFormatAudioCodec(audio_codec):
if audio_codec == 'AC-3':
return 'AC3'
if audio_codec == 'E-AC-3':
return 'EAC3'
if audio_codec == 'MPEG Audio':
return 'MP3'
return audio_codec
def SonarrFormatVideoCodec(video_codec):
if video_codec == 'x264' or video_codec == 'AVC':
return 'h264'
elif video_codec == 'x265' or video_codec == 'HEVC':
return 'h265'
elif video_codec.startswith('XviD'):
return 'XviD'
elif video_codec.startswith('DivX'):
return 'DivX'
elif video_codec == 'MPEG-1 Video':
return 'Mpeg'
elif video_codec == 'MPEG-2 Video':
return 'Mpeg2'
elif video_codec == 'MPEG-4 Video':
return 'Mpeg4'
elif video_codec == 'VC-1':
return 'VC1'
elif video_codec.endswith('VP6'):
return 'VP6'
elif video_codec.endswith('VP7'):
return 'VP7'
elif video_codec.endswith('VP8'):
return 'VP8'
elif video_codec.endswith('VP9'):
return 'VP9'
else:
logging.debug("BAZARR More than 5 episodes were added during this sync then we wont search for subtitles right now.")
return video_codec
def SonarrFormatAudioCodec(audioCodec):
if audioCodec == 'AC-3': return 'AC3'
if audioCodec == 'E-AC-3': return 'EAC3'
if audioCodec == 'MPEG Audio': return 'MP3'
def episodeParser(episode):
if 'hasFile' in episode:
if episode['hasFile'] is True:
if 'episodeFile' in episode:
if episode['episodeFile']['size'] > 20480:
if 'sceneName' in episode['episodeFile']:
sceneName = episode['episodeFile']['sceneName']
else:
sceneName = None
return audioCodec
audio_language = []
if 'language' in episode['episodeFile'] and len(episode['episodeFile']['language']):
item = episode['episodeFile']['language']
if isinstance(item, dict):
if 'name' in item:
audio_language.append(item['name'])
else:
audio_language = database.execute("SELECT audio_language FROM table_shows WHERE "
"sonarrSeriesId=?", (episode['seriesId'],),
only_one=True)['audio_language']
if 'mediaInfo' in episode['episodeFile']:
if 'videoCodec' in episode['episodeFile']['mediaInfo']:
videoCodec = episode['episodeFile']['mediaInfo']['videoCodec']
videoCodec = SonarrFormatVideoCodec(videoCodec)
else:
videoCodec = None
def SonarrFormatVideoCodec(videoCodec):
if videoCodec == 'x264' or videoCodec == 'AVC': return 'h264'
if videoCodec == 'x265' or videoCodec == 'HEVC': return 'h265'
if videoCodec.startswith('XviD'): return 'XviD'
if videoCodec.startswith('DivX'): return 'DivX'
if videoCodec == 'MPEG-1 Video': return 'Mpeg'
if videoCodec == 'MPEG-2 Video': return 'Mpeg2'
if videoCodec == 'MPEG-4 Video': return 'Mpeg4'
if videoCodec == 'VC-1': return 'VC1'
if videoCodec.endswith('VP6'): return 'VP6'
if videoCodec.endswith('VP7'): return 'VP7'
if videoCodec.endswith('VP8'): return 'VP8'
if videoCodec.endswith('VP9'): return 'VP9'
if 'audioCodec' in episode['episodeFile']['mediaInfo']:
audioCodec = episode['episodeFile']['mediaInfo']['audioCodec']
audioCodec = SonarrFormatAudioCodec(audioCodec)
else:
audioCodec = None
else:
videoCodec = None
audioCodec = None
return videoCodec
try:
video_format, video_resolution = episode['episodeFile']['quality']['quality']['name'].split('-')
except:
video_format = episode['episodeFile']['quality']['quality']['name']
try:
video_resolution = str(episode['episodeFile']['quality']['quality']['resolution']) + 'p'
except:
video_resolution = None
return {'sonarrSeriesId': episode['seriesId'],
'sonarrEpisodeId': episode['id'],
'title': episode['title'],
'path': episode['episodeFile']['path'],
'season': episode['seasonNumber'],
'episode': episode['episodeNumber'],
'scene_name': sceneName,
'monitored': str(bool(episode['monitored'])),
'format': video_format,
'resolution': video_resolution,
'video_codec': videoCodec,
'audio_codec': audioCodec,
'episode_file_id': episode['episodeFile']['id'],
'audio_language': str(audio_language),
'file_size': episode['episodeFile']['size']}
def get_series_from_sonarr_api(url, apikey_sonarr):
url_sonarr_api_series = url + "/api/series?apikey=" + apikey_sonarr
try:
r = requests.get(url_sonarr_api_series, timeout=60, verify=False, headers=headers)
r.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code:
raise requests.exceptions.HTTPError
logging.exception("BAZARR Error trying to get series from Sonarr. Http error.")
return
except requests.exceptions.ConnectionError:
logging.exception("BAZARR Error trying to get series from Sonarr. Connection Error.")
return
except requests.exceptions.Timeout:
logging.exception("BAZARR Error trying to get series from Sonarr. Timeout Error.")
return
except requests.exceptions.RequestException:
logging.exception("BAZARR Error trying to get series from Sonarr.")
return
else:
series_list = []
for series in r.json():
series_list.append({'sonarrSeriesId': series['id'], 'title': series['title']})
return series_list
def get_episodes_from_sonarr_api(url, apikey_sonarr, series_id=None, episode_id=None):
if series_id:
url_sonarr_api_episode = url + "/api/episode?seriesId={}&apikey=".format(series_id) + apikey_sonarr
elif episode_id:
url_sonarr_api_episode = url + "/api/episode/{}?apikey=".format(episode_id) + apikey_sonarr
else:
return
try:
r = requests.get(url_sonarr_api_episode, timeout=60, verify=False, headers=headers)
r.raise_for_status()
except requests.exceptions.HTTPError:
logging.exception("BAZARR Error trying to get episodes from Sonarr. Http error.")
return
except requests.exceptions.ConnectionError:
logging.exception("BAZARR Error trying to get episodes from Sonarr. Connection Error.")
return
except requests.exceptions.Timeout:
logging.exception("BAZARR Error trying to get episodes from Sonarr. Timeout Error.")
return
except requests.exceptions.RequestException:
logging.exception("BAZARR Error trying to get episodes from Sonarr.")
return
else:
return r.json()

@ -12,6 +12,7 @@ from get_rootfolder import check_radarr_rootfolder
from get_subtitle import movies_download_subtitles
from database import database, dict_converter, get_exclusion_clause
from event_handler import event_stream
headers = {"User-Agent": os.environ["SZ_USER_AGENT"]}
@ -43,25 +44,9 @@ def update_movies():
tagsDict = get_tags()
# Get movies data from radarr
if radarr_version.startswith('0'):
url_radarr_api_movies = url_radarr() + "/api/movie?apikey=" + apikey_radarr
else:
url_radarr_api_movies = url_radarr() + "/api/v3/movie?apikey=" + apikey_radarr
try:
r = requests.get(url_radarr_api_movies, timeout=60, verify=False, headers=headers)
r.raise_for_status()
except requests.exceptions.HTTPError as errh:
logging.exception("BAZARR Error trying to get movies from Radarr. Http error.")
return
except requests.exceptions.ConnectionError as errc:
logging.exception("BAZARR Error trying to get movies from Radarr. Connection Error.")
return
except requests.exceptions.Timeout as errt:
logging.exception("BAZARR Error trying to get movies from Radarr. Timeout Error.")
return
except requests.exceptions.RequestException as err:
logging.exception("BAZARR Error trying to get movies from Radarr.")
movies = get_movies_from_radarr_api(radarr_version=radarr_version, url=url_radarr(),
apikey_radarr=apikey_radarr)
if not movies:
return
else:
# Get current movies in DB
@ -74,152 +59,26 @@ def update_movies():
movies_to_add = []
altered_movies = []
moviesIdListLength = len(r.json())
for i, movie in enumerate(r.json(), 1):
# Build new and updated movies
for movie in movies:
if movie['hasFile'] is True:
if 'movieFile' in movie:
# Detect file separator
if movie['path'][0] == "/":
separator = "/"
else:
separator = "\\"
if movie["path"] != None and movie['movieFile']['relativePath'] != None:
try:
overview = str(movie['overview'])
except:
overview = ""
try:
poster_big = movie['images'][0]['url']
poster = os.path.splitext(poster_big)[0] + '-500' + os.path.splitext(poster_big)[1]
except:
poster = ""
try:
fanart = movie['images'][1]['url']
except:
fanart = ""
if 'sceneName' in movie['movieFile']:
sceneName = movie['movieFile']['sceneName']
else:
sceneName = None
alternativeTitles = None
if radarr_version.startswith('0'):
if 'alternativeTitles' in movie:
alternativeTitles = str([item['title'] for item in movie['alternativeTitles']])
else:
if 'alternateTitles' in movie:
alternativeTitles = str([item['title'] for item in movie['alternateTitles']])
if 'imdbId' in movie: imdbId = movie['imdbId']
else: imdbId = None
try:
format, resolution = movie['movieFile']['quality']['quality']['name'].split('-')
except:
format = movie['movieFile']['quality']['quality']['name']
try:
resolution = str(movie['movieFile']['quality']['quality']['resolution']) + 'p'
except:
resolution = None
if 'mediaInfo' in movie['movieFile']:
videoFormat = videoCodecID = videoProfile = videoCodecLibrary = None
if radarr_version.startswith('0'):
if 'videoFormat' in movie['movieFile']['mediaInfo']: videoFormat = movie['movieFile']['mediaInfo']['videoFormat']
else:
if 'videoCodec' in movie['movieFile']['mediaInfo']: videoFormat = movie['movieFile']['mediaInfo']['videoCodec']
if 'videoCodecID' in movie['movieFile']['mediaInfo']: videoCodecID = movie['movieFile']['mediaInfo']['videoCodecID']
if 'videoProfile' in movie['movieFile']['mediaInfo']: videoProfile = movie['movieFile']['mediaInfo']['videoProfile']
if 'videoCodecLibrary' in movie['movieFile']['mediaInfo']: videoCodecLibrary = movie['movieFile']['mediaInfo']['videoCodecLibrary']
videoCodec = RadarrFormatVideoCodec(videoFormat, videoCodecID, videoCodecLibrary)
audioFormat = audioCodecID = audioProfile = audioAdditionalFeatures = None
if radarr_version.startswith('0'):
if 'audioFormat' in movie['movieFile']['mediaInfo']: audioFormat = movie['movieFile']['mediaInfo']['audioFormat']
else:
if 'audioCodec' in movie['movieFile']['mediaInfo']: audioFormat = movie['movieFile']['mediaInfo']['audioCodec']
if 'audioCodecID' in movie['movieFile']['mediaInfo']: audioCodecID = movie['movieFile']['mediaInfo']['audioCodecID']
if 'audioProfile' in movie['movieFile']['mediaInfo']: audioProfile = movie['movieFile']['mediaInfo']['audioProfile']
if 'audioAdditionalFeatures' in movie['movieFile']['mediaInfo']: audioAdditionalFeatures = movie['movieFile']['mediaInfo']['audioAdditionalFeatures']
audioCodec = RadarrFormatAudioCodec(audioFormat, audioCodecID, audioProfile, audioAdditionalFeatures)
else:
videoCodec = None
audioCodec = None
audio_language = []
if radarr_version.startswith('0'):
if 'mediaInfo' in movie['movieFile']:
if 'audioLanguages' in movie['movieFile']['mediaInfo']:
audio_languages_list = movie['movieFile']['mediaInfo']['audioLanguages'].split('/')
if len(audio_languages_list):
for audio_language_list in audio_languages_list:
audio_language.append(audio_language_list.strip())
if not audio_language:
audio_language = profile_id_to_language(movie['qualityProfileId'], audio_profiles)
else:
if 'languages' in movie['movieFile'] and len(movie['movieFile']['languages']):
for item in movie['movieFile']['languages']:
if isinstance(item, dict):
if 'name' in item:
audio_language.append(item['name'])
tags = [d['label'] for d in tagsDict if d['id'] in movie['tags']]
if movie['movieFile']['size'] > 20480:
# Add movies in radarr to current movies list
current_movies_radarr.append(str(movie['tmdbId']))
if str(movie['tmdbId']) in current_movies_db_list:
movies_to_update.append({'radarrId': int(movie["id"]),
'title': movie["title"],
'path': movie["path"] + separator + movie['movieFile']['relativePath'],
'tmdbId': str(movie["tmdbId"]),
'poster': poster,
'fanart': fanart,
'audio_language': str(audio_language),
'sceneName': sceneName,
'monitored': str(bool(movie['monitored'])),
'year': str(movie['year']),
'sortTitle': movie['sortTitle'],
'alternativeTitles': alternativeTitles,
'format': format,
'resolution': resolution,
'video_codec': videoCodec,
'audio_codec': audioCodec,
'overview': overview,
'imdbId': imdbId,
'movie_file_id': int(movie['movieFile']['id']),
'tags': str(tags),
'file_size': movie['movieFile']['size']})
movies_to_update.append(movieParser(movie, action='update',
radarr_version=radarr_version,
tags_dict=tagsDict,
movie_default_profile=movie_default_profile,
audio_profiles=audio_profiles))
else:
movies_to_add.append({'radarrId': int(movie["id"]),
'title': movie["title"],
'path': movie["path"] + separator + movie['movieFile']['relativePath'],
'tmdbId': str(movie["tmdbId"]),
'subtitles': '[]',
'overview': overview,
'poster': poster,
'fanart': fanart,
'audio_language': str(audio_language),
'sceneName': sceneName,
'monitored': str(bool(movie['monitored'])),
'sortTitle': movie['sortTitle'],
'year': str(movie['year']),
'alternativeTitles': alternativeTitles,
'format': format,
'resolution': resolution,
'video_codec': videoCodec,
'audio_codec': audioCodec,
'imdbId': imdbId,
'movie_file_id': int(movie['movieFile']['id']),
'tags': str(tags),
'profileId': movie_default_profile,
'file_size': movie['movieFile']['size']})
else:
logging.error(
'BAZARR Radarr returned a movie without a file path: ' + movie["path"] + separator +
movie['movieFile']['relativePath'])
movies_to_add.append(movieParser(movie, action='insert',
radarr_version=radarr_version,
tags_dict=tagsDict,
movie_default_profile=movie_default_profile,
audio_profiles=audio_profiles))
# Remove old movies from DB
removed_movies = list(set(current_movies_db_list) - set(current_movies_radarr))
@ -283,6 +142,95 @@ def update_movies():
logging.debug("BAZARR More than 5 movies were added during this sync then we wont search for subtitles.")
def update_one_movie(movie_id, action):
logging.debug('BAZARR syncing this specific movie from Radarr: {}'.format(movie_id))
# Check if there's a row in database for this movie ID
existing_movie = database.execute('SELECT path FROM table_movies WHERE radarrId = ?', (movie_id,), only_one=True)
# Remove movie from DB
if action == 'deleted':
if existing_movie:
database.execute("DELETE FROM table_movies WHERE radarrId=?", (movie_id,))
event_stream(type='movie', action='delete', payload=int(movie_id))
logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie(
existing_movie['path'])))
return
radarr_version = get_radarr_version()
movie_default_enabled = settings.general.getboolean('movie_default_enabled')
if movie_default_enabled is True:
movie_default_profile = settings.general.movie_default_profile
if movie_default_profile == '':
movie_default_profile = None
else:
movie_default_profile = None
audio_profiles = get_profile_list()
tagsDict = get_tags()
try:
# Get movie data from radarr api
movie = None
movie_data = get_movies_from_radarr_api(radarr_version=radarr_version, url=url_radarr(),
apikey_radarr=settings.radarr.apikey, radarr_id=movie_id)
if not movie_data:
return
else:
if action == 'updated' and existing_movie:
movie = movieParser(movie_data, action='update', radarr_version=radarr_version,
tags_dict=tagsDict, movie_default_profile=movie_default_profile,
audio_profiles=audio_profiles)
elif action == 'updated' and not existing_movie:
movie = movieParser(movie_data, action='insert', radarr_version=radarr_version,
tags_dict=tagsDict, movie_default_profile=movie_default_profile,
audio_profiles=audio_profiles)
except Exception:
logging.debug('BAZARR cannot get movie returned by SignalR feed from Radarr API.')
return
# Drop useless events
if not movie and not existing_movie:
return
# Remove movie from DB
if not movie and existing_movie:
database.execute("DELETE FROM table_movies WHERE radarrId=?", (movie_id,))
event_stream(type='movie', action='delete', payload=int(movie_id))
logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie(
existing_movie['path'])))
return
# Update existing movie in DB
elif movie and existing_movie:
query = dict_converter.convert(movie)
database.execute('''UPDATE table_movies SET ''' + query.keys_update + ''' WHERE radarrId = ?''',
query.values + (movie['radarrId'],))
event_stream(type='movie', action='update', payload=int(movie_id))
logging.debug('BAZARR updated this movie into the database:{}'.format(path_mappings.path_replace_movie(
movie['path'])))
# Insert new movie in DB
elif movie and not existing_movie:
query = dict_converter.convert(movie)
database.execute('''INSERT OR IGNORE INTO table_movies(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
event_stream(type='movie', action='update', payload=int(movie_id))
logging.debug('BAZARR inserted this movie into the database:{}'.format(path_mappings.path_replace_movie(
movie['path'])))
# Storing existing subtitles
logging.debug('BAZARR storing subtitles for this movie: {}'.format(path_mappings.path_replace_movie(
movie['path'])))
store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']))
# Downloading missing subtitles
logging.debug('BAZARR downloading missing subtitles for this movie: {}'.format(path_mappings.path_replace_movie(
movie['path'])))
movies_download_subtitles(movie_id)
def get_profile_list():
apikey_radarr = settings.radarr.apikey
radarr_version = get_radarr_version()
@ -373,11 +321,11 @@ def get_tags():
apikey_radarr = settings.radarr.apikey
tagsDict = []
# Get tags data from Sonarr
url_sonarr_api_series = url_radarr() + "/api/tag?apikey=" + apikey_radarr
# Get tags data from Radarr
url_radarr_api_series = url_radarr() + "/api/tag?apikey=" + apikey_radarr
try:
tagsDict = requests.get(url_sonarr_api_series, timeout=60, verify=False, headers=headers)
tagsDict = requests.get(url_radarr_api_series, timeout=60, verify=False, headers=headers)
except requests.exceptions.ConnectionError:
logging.exception("BAZARR Error trying to get tags from Radarr. Connection Error.")
return []
@ -389,3 +337,183 @@ def get_tags():
return []
else:
return tagsDict.json()
def movieParser(movie, action, radarr_version, tags_dict, movie_default_profile, audio_profiles):
if 'movieFile' in movie:
# Detect file separator
if movie['path'][0] == "/":
separator = "/"
else:
separator = "\\"
try:
overview = str(movie['overview'])
except:
overview = ""
try:
poster_big = movie['images'][0]['url']
poster = os.path.splitext(poster_big)[0] + '-500' + os.path.splitext(poster_big)[1]
except:
poster = ""
try:
fanart = movie['images'][1]['url']
except:
fanart = ""
if 'sceneName' in movie['movieFile']:
sceneName = movie['movieFile']['sceneName']
else:
sceneName = None
alternativeTitles = None
if radarr_version.startswith('0'):
if 'alternativeTitles' in movie:
alternativeTitles = str([item['title'] for item in movie['alternativeTitles']])
else:
if 'alternateTitles' in movie:
alternativeTitles = str([item['title'] for item in movie['alternateTitles']])
if 'imdbId' in movie:
imdbId = movie['imdbId']
else:
imdbId = None
try:
format, resolution = movie['movieFile']['quality']['quality']['name'].split('-')
except:
format = movie['movieFile']['quality']['quality']['name']
try:
resolution = str(movie['movieFile']['quality']['quality']['resolution']) + 'p'
except:
resolution = None
if 'mediaInfo' in movie['movieFile']:
videoFormat = videoCodecID = videoProfile = videoCodecLibrary = None
if radarr_version.startswith('0'):
if 'videoFormat' in movie['movieFile']['mediaInfo']: videoFormat = \
movie['movieFile']['mediaInfo']['videoFormat']
else:
if 'videoCodec' in movie['movieFile']['mediaInfo']: videoFormat = \
movie['movieFile']['mediaInfo']['videoCodec']
if 'videoCodecID' in movie['movieFile']['mediaInfo']: videoCodecID = \
movie['movieFile']['mediaInfo']['videoCodecID']
if 'videoProfile' in movie['movieFile']['mediaInfo']: videoProfile = \
movie['movieFile']['mediaInfo']['videoProfile']
if 'videoCodecLibrary' in movie['movieFile']['mediaInfo']: videoCodecLibrary = \
movie['movieFile']['mediaInfo']['videoCodecLibrary']
videoCodec = RadarrFormatVideoCodec(videoFormat, videoCodecID, videoCodecLibrary)
audioFormat = audioCodecID = audioProfile = audioAdditionalFeatures = None
if radarr_version.startswith('0'):
if 'audioFormat' in movie['movieFile']['mediaInfo']: audioFormat = \
movie['movieFile']['mediaInfo']['audioFormat']
else:
if 'audioCodec' in movie['movieFile']['mediaInfo']: audioFormat = \
movie['movieFile']['mediaInfo']['audioCodec']
if 'audioCodecID' in movie['movieFile']['mediaInfo']: audioCodecID = \
movie['movieFile']['mediaInfo']['audioCodecID']
if 'audioProfile' in movie['movieFile']['mediaInfo']: audioProfile = \
movie['movieFile']['mediaInfo']['audioProfile']
if 'audioAdditionalFeatures' in movie['movieFile']['mediaInfo']: audioAdditionalFeatures = \
movie['movieFile']['mediaInfo']['audioAdditionalFeatures']
audioCodec = RadarrFormatAudioCodec(audioFormat, audioCodecID, audioProfile,
audioAdditionalFeatures)
else:
videoCodec = None
audioCodec = None
audio_language = []
if radarr_version.startswith('0'):
if 'mediaInfo' in movie['movieFile']:
if 'audioLanguages' in movie['movieFile']['mediaInfo']:
audio_languages_list = movie['movieFile']['mediaInfo']['audioLanguages'].split('/')
if len(audio_languages_list):
for audio_language_list in audio_languages_list:
audio_language.append(audio_language_list.strip())
if not audio_language:
audio_language = profile_id_to_language(movie['qualityProfileId'], audio_profiles)
else:
if 'languages' in movie['movieFile'] and len(movie['movieFile']['languages']):
for item in movie['movieFile']['languages']:
if isinstance(item, dict):
if 'name' in item:
audio_language.append(item['name'])
tags = [d['label'] for d in tags_dict if d['id'] in movie['tags']]
if action == 'update':
return {'radarrId': int(movie["id"]),
'title': movie["title"],
'path': movie["path"] + separator + movie['movieFile']['relativePath'],
'tmdbId': str(movie["tmdbId"]),
'poster': poster,
'fanart': fanart,
'audio_language': str(audio_language),
'sceneName': sceneName,
'monitored': str(bool(movie['monitored'])),
'year': str(movie['year']),
'sortTitle': movie['sortTitle'],
'alternativeTitles': alternativeTitles,
'format': format,
'resolution': resolution,
'video_codec': videoCodec,
'audio_codec': audioCodec,
'overview': overview,
'imdbId': imdbId,
'movie_file_id': int(movie['movieFile']['id']),
'tags': str(tags),
'file_size': movie['movieFile']['size']}
else:
return {'radarrId': int(movie["id"]),
'title': movie["title"],
'path': movie["path"] + separator + movie['movieFile']['relativePath'],
'tmdbId': str(movie["tmdbId"]),
'subtitles': '[]',
'overview': overview,
'poster': poster,
'fanart': fanart,
'audio_language': str(audio_language),
'sceneName': sceneName,
'monitored': str(bool(movie['monitored'])),
'sortTitle': movie['sortTitle'],
'year': str(movie['year']),
'alternativeTitles': alternativeTitles,
'format': format,
'resolution': resolution,
'video_codec': videoCodec,
'audio_codec': audioCodec,
'imdbId': imdbId,
'movie_file_id': int(movie['movieFile']['id']),
'tags': str(tags),
'profileId': movie_default_profile,
'file_size': movie['movieFile']['size']}
def get_movies_from_radarr_api(radarr_version, url, apikey_radarr, radarr_id=None):
if radarr_version.startswith('0'):
url_radarr_api_movies = url + "/api/movie" + ("/{}".format(radarr_id) if radarr_id else "") + "?apikey=" + \
apikey_radarr
else:
url_radarr_api_movies = url + "/api/v3/movie" + ("/{}".format(radarr_id) if radarr_id else "") + "?apikey=" + \
apikey_radarr
try:
r = requests.get(url_radarr_api_movies, timeout=60, verify=False, headers=headers)
if r.status_code == 404:
return
r.raise_for_status()
except requests.exceptions.HTTPError as errh:
logging.exception("BAZARR Error trying to get movies from Radarr. Http error.")
return
except requests.exceptions.ConnectionError as errc:
logging.exception("BAZARR Error trying to get movies from Radarr. Connection Error.")
return
except requests.exceptions.Timeout as errt:
logging.exception("BAZARR Error trying to get movies from Radarr. Timeout Error.")
return
except requests.exceptions.RequestException as err:
logging.exception("BAZARR Error trying to get movies from Radarr.")
return
else:
return r.json()

@ -35,131 +35,144 @@ def update_series():
tagsDict = get_tags()
# Get shows data from Sonarr
url_sonarr_api_series = url_sonarr() + "/api/series?apikey=" + apikey_sonarr
try:
r = requests.get(url_sonarr_api_series, timeout=60, verify=False, headers=headers)
r.raise_for_status()
except requests.exceptions.HTTPError:
logging.exception("BAZARR Error trying to get series from Sonarr. Http error.")
return
except requests.exceptions.ConnectionError:
logging.exception("BAZARR Error trying to get series from Sonarr. Connection Error.")
return
except requests.exceptions.Timeout:
logging.exception("BAZARR Error trying to get series from Sonarr. Timeout Error.")
return
except requests.exceptions.RequestException:
logging.exception("BAZARR Error trying to get series from Sonarr.")
series = get_series_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr)
if not series:
return
else:
# Get current shows in DB
current_shows_db = database.execute("SELECT sonarrSeriesId FROM table_shows")
current_shows_db_list = [x['sonarrSeriesId'] for x in current_shows_db]
current_shows_sonarr = []
series_to_update = []
series_to_add = []
for show in series:
# Add shows in Sonarr to current shows list
current_shows_sonarr.append(show['id'])
if show['id'] in current_shows_db_list:
series_to_update.append(seriesParser(show, action='update', sonarr_version=sonarr_version,
tags_dict=tagsDict, serie_default_profile=serie_default_profile,
audio_profiles=audio_profiles))
else:
series_to_add.append(seriesParser(show, action='insert', sonarr_version=sonarr_version,
tags_dict=tagsDict, serie_default_profile=serie_default_profile,
audio_profiles=audio_profiles))
# Remove old series from DB
removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr))
for series in removed_series:
database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?",(series,))
event_stream(type='series', action='delete', payload=series)
# Update existing series in DB
series_in_db_list = []
series_in_db = database.execute("SELECT title, path, tvdbId, sonarrSeriesId, overview, poster, fanart, "
"audio_language, sortTitle, year, alternateTitles, tags, seriesType, imdbId "
"FROM table_shows")
for item in series_in_db:
series_in_db_list.append(item)
series_to_update_list = [i for i in series_to_update if i not in series_in_db_list]
for updated_series in series_to_update_list:
query = dict_converter.convert(updated_series)
database.execute('''UPDATE table_shows SET ''' + query.keys_update + ''' WHERE sonarrSeriesId = ?''',
query.values + (updated_series['sonarrSeriesId'],))
event_stream(type='series', payload=updated_series['sonarrSeriesId'])
# Insert new series in DB
for added_series in series_to_add:
query = dict_converter.convert(added_series)
result = database.execute(
'''INSERT OR IGNORE INTO table_shows(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
if result:
list_missing_subtitles(no=added_series['sonarrSeriesId'])
else:
logging.debug('BAZARR unable to insert this series into the database:',
path_mappings.path_replace(added_series['path']))
event_stream(type='series', action='update', payload=added_series['sonarrSeriesId'])
logging.debug('BAZARR All series synced from Sonarr into database.')
def update_one_series(series_id, action):
logging.debug('BAZARR syncing this specific series from RSonarr: {}'.format(series_id))
# Check if there's a row in database for this series ID
existing_series = database.execute('SELECT path FROM table_shows WHERE sonarrSeriesId = ?', (series_id,),
only_one=True)
# Get current shows in DB
current_shows_db = database.execute("SELECT sonarrSeriesId FROM table_shows")
current_shows_db_list = [x['sonarrSeriesId'] for x in current_shows_db]
current_shows_sonarr = []
series_to_update = []
series_to_add = []
series_list_length = len(r.json())
for i, show in enumerate(r.json(), 1):
overview = show['overview'] if 'overview' in show else ''
poster = ''
fanart = ''
for image in show['images']:
if image['coverType'] == 'poster':
poster_big = image['url'].split('?')[0]
poster = os.path.splitext(poster_big)[0] + '-250' + os.path.splitext(poster_big)[1]
if image['coverType'] == 'fanart':
fanart = image['url'].split('?')[0]
alternate_titles = None
if show['alternateTitles'] is not None:
alternate_titles = str([item['title'] for item in show['alternateTitles']])
audio_language = []
if sonarr_version.startswith('2'):
audio_language = profile_id_to_language(show['qualityProfileId'], audio_profiles)
else:
audio_language = profile_id_to_language(show['languageProfileId'], audio_profiles)
tags = [d['label'] for d in tagsDict if d['id'] in show['tags']]
imdbId = show['imdbId'] if 'imdbId' in show else None
# Add shows in Sonarr to current shows list
current_shows_sonarr.append(show['id'])
if show['id'] in current_shows_db_list:
series_to_update.append({'title': show["title"],
'path': show["path"],
'tvdbId': int(show["tvdbId"]),
'sonarrSeriesId': int(show["id"]),
'overview': overview,
'poster': poster,
'fanart': fanart,
'audio_language': str(audio_language),
'sortTitle': show['sortTitle'],
'year': str(show['year']),
'alternateTitles': alternate_titles,
'tags': str(tags),
'seriesType': show['seriesType'],
'imdbId': imdbId})
else:
series_to_add.append({'title': show["title"],
'path': show["path"],
'tvdbId': show["tvdbId"],
'sonarrSeriesId': show["id"],
'overview': overview,
'poster': poster,
'fanart': fanart,
'audio_language': str(audio_language),
'sortTitle': show['sortTitle'],
'year': str(show['year']),
'alternateTitles': alternate_titles,
'tags': str(tags),
'seriesType': show['seriesType'],
'imdbId': imdbId,
'profileId': serie_default_profile})
# Remove old series from DB
removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr))
for series in removed_series:
database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?",(series,))
event_stream(type='series', action='delete', payload=series)
sonarr_version = get_sonarr_version()
serie_default_enabled = settings.general.getboolean('serie_default_enabled')
# Update existing series in DB
series_in_db_list = []
series_in_db = database.execute("SELECT title, path, tvdbId, sonarrSeriesId, overview, poster, fanart, "
"audio_language, sortTitle, year, alternateTitles, tags, seriesType, imdbId FROM table_shows")
if serie_default_enabled is True:
serie_default_profile = settings.general.serie_default_profile
if serie_default_profile == '':
serie_default_profile = None
else:
serie_default_profile = None
for item in series_in_db:
series_in_db_list.append(item)
audio_profiles = get_profile_list()
tagsDict = get_tags()
series_to_update_list = [i for i in series_to_update if i not in series_in_db_list]
try:
# Get series data from sonarr api
series = None
try:
series_data = get_series_from_sonarr_api(url=url_sonarr(), apikey_sonarr=settings.sonarr.apikey,
sonarr_series_id=series_id)
except requests.exceptions.HTTPError:
database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?", (series_id,))
event_stream(type='series', action='delete', payload=int(series_id))
return
if not series_data:
return
else:
if action == 'updated' and existing_series:
series = seriesParser(series_data, action='update', sonarr_version=sonarr_version,
tags_dict=tagsDict, serie_default_profile=serie_default_profile,
audio_profiles=audio_profiles)
elif action == 'updated' and not existing_series:
series = seriesParser(series_data, action='insert', sonarr_version=sonarr_version,
tags_dict=tagsDict, serie_default_profile=serie_default_profile,
audio_profiles=audio_profiles)
except Exception:
logging.debug('BAZARR cannot parse series returned by SignalR feed.')
return
for updated_series in series_to_update_list:
query = dict_converter.convert(updated_series)
# Remove series from DB
if action == 'deleted':
database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?", (series_id,))
event_stream(type='series', action='delete', payload=int(series_id))
logging.debug('BAZARR deleted this series from the database:{}'.format(path_mappings.path_replace(
existing_series['path'])))
return
# Update existing series in DB
elif action == 'updated' and existing_series:
query = dict_converter.convert(series)
database.execute('''UPDATE table_shows SET ''' + query.keys_update + ''' WHERE sonarrSeriesId = ?''',
query.values + (updated_series['sonarrSeriesId'],))
event_stream(type='series', payload=updated_series['sonarrSeriesId'])
query.values + (series['sonarrSeriesId'],))
event_stream(type='series', action='update', payload=int(series_id))
logging.debug('BAZARR updated this series into the database:{}'.format(path_mappings.path_replace(
series['path'])))
# Insert new series in DB
for added_series in series_to_add:
query = dict_converter.convert(added_series)
result = database.execute(
'''INSERT OR IGNORE INTO table_shows(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
if result:
list_missing_subtitles(no=added_series['sonarrSeriesId'])
else:
logging.debug('BAZARR unable to insert this series into the database:',
path_mappings.path_replace(added_series['path']))
event_stream(type='series', series=added_series['sonarrSeriesId'])
logging.debug('BAZARR All series synced from Sonarr into database.')
elif action == 'updated' and not existing_series:
query = dict_converter.convert(series)
database.execute('''INSERT OR IGNORE INTO table_shows(''' + query.keys_insert + ''') VALUES(''' +
query.question_marks + ''')''', query.values)
event_stream(type='series', action='update', payload=int(series_id))
logging.debug('BAZARR inserted this series into the database:{}'.format(path_mappings.path_replace(
series['path'])))
def get_profile_list():
@ -224,3 +237,86 @@ def get_tags():
return []
else:
return tagsDict.json()
def seriesParser(show, action, sonarr_version, tags_dict, serie_default_profile, audio_profiles):
overview = show['overview'] if 'overview' in show else ''
poster = ''
fanart = ''
for image in show['images']:
if image['coverType'] == 'poster':
poster_big = image['url'].split('?')[0]
poster = os.path.splitext(poster_big)[0] + '-250' + os.path.splitext(poster_big)[1]
if image['coverType'] == 'fanart':
fanart = image['url'].split('?')[0]
alternate_titles = None
if show['alternateTitles'] is not None:
alternate_titles = str([item['title'] for item in show['alternateTitles']])
audio_language = []
if sonarr_version.startswith('2'):
audio_language = profile_id_to_language(show['qualityProfileId'], audio_profiles)
else:
audio_language = profile_id_to_language(show['languageProfileId'], audio_profiles)
tags = [d['label'] for d in tags_dict if d['id'] in show['tags']]
imdbId = show['imdbId'] if 'imdbId' in show else None
if action == 'update':
return {'title': show["title"],
'path': show["path"],
'tvdbId': int(show["tvdbId"]),
'sonarrSeriesId': int(show["id"]),
'overview': overview,
'poster': poster,
'fanart': fanart,
'audio_language': str(audio_language),
'sortTitle': show['sortTitle'],
'year': str(show['year']),
'alternateTitles': alternate_titles,
'tags': str(tags),
'seriesType': show['seriesType'],
'imdbId': imdbId}
else:
return {'title': show["title"],
'path': show["path"],
'tvdbId': show["tvdbId"],
'sonarrSeriesId': show["id"],
'overview': overview,
'poster': poster,
'fanart': fanart,
'audio_language': str(audio_language),
'sortTitle': show['sortTitle'],
'year': str(show['year']),
'alternateTitles': alternate_titles,
'tags': str(tags),
'seriesType': show['seriesType'],
'imdbId': imdbId,
'profileId': serie_default_profile}
def get_series_from_sonarr_api(url, apikey_sonarr, sonarr_series_id=None):
url_sonarr_api_series = url + "/api/series" + ("/{}".format(sonarr_series_id) if sonarr_series_id else "") + \
"?apikey=" + apikey_sonarr
try:
r = requests.get(url_sonarr_api_series, timeout=60, verify=False, headers=headers)
r.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code:
raise requests.exceptions.HTTPError
logging.exception("BAZARR Error trying to get series from Sonarr. Http error.")
return
except requests.exceptions.ConnectionError:
logging.exception("BAZARR Error trying to get series from Sonarr. Connection Error.")
return
except requests.exceptions.Timeout:
logging.exception("BAZARR Error trying to get series from Sonarr. Timeout Error.")
return
except requests.exceptions.RequestException:
logging.exception("BAZARR Error trying to get series from Sonarr.")
return
else:
return r.json()

@ -103,8 +103,12 @@ def configure_logging(debug=False):
logging.getLogger("ffsubsync.speech_transformers").setLevel(logging.ERROR)
logging.getLogger("ffsubsync.ffsubsync").setLevel(logging.ERROR)
logging.getLogger("srt").setLevel(logging.ERROR)
logging.getLogger("SignalRCoreClient").setLevel(logging.CRITICAL)
logging.getLogger("websocket").setLevel(logging.CRITICAL)
logging.getLogger("geventwebsocket.handler").setLevel(logging.WARNING)
logging.getLogger("geventwebsocket.handler").setLevel(logging.WARNING)
logging.getLogger("engineio.server").setLevel(logging.WARNING)
logging.getLogger("knowit").setLevel(logging.CRITICAL)
logging.getLogger("enzyme").setLevel(logging.CRITICAL)
logging.getLogger("guessit").setLevel(logging.WARNING)

@ -1,5 +1,13 @@
# coding=utf-8
# Gevent monkey patch if gevent available. If not, it will be installed on during the init process.
try:
from gevent import monkey
except ImportError:
pass
else:
monkey.patch_all()
import os
bazarr_version = 'unknown'
@ -12,14 +20,9 @@ if os.path.isfile(version_file):
os.environ["BAZARR_VERSION"] = bazarr_version
import gc
import libs
import hashlib
import calendar
from get_args import args
from logger import empty_log
from config import settings, url_sonarr, url_radarr, configure_proxy_func, base_url
from init import *
@ -28,13 +31,14 @@ from database import database
from notifier import update_notifier
from urllib.parse import unquote
from get_languages import load_language_in_db, language_from_alpha2, alpha3_from_alpha2
from get_languages import load_language_in_db
from flask import make_response, request, redirect, abort, render_template, Response, session, flash, url_for, \
send_file, stream_with_context
from get_series import *
from get_episodes import *
from get_movies import *
from signalr_client import sonarr_signalr_client, radarr_signalr_client
from check_update import apply_update, check_if_new_update, check_releases
from server import app, webserver
@ -186,5 +190,11 @@ def proxy(protocol, url):
return dict(status=False, error=result.raise_for_status())
if settings.general.getboolean('use_sonarr'):
sonarr_signalr_client.start()
if settings.general.getboolean('use_radarr'):
radarr_signalr_client.start()
if __name__ == "__main__":
webserver.start()

@ -0,0 +1,153 @@
# coding=utf-8
import logging
import gevent
import threading
from requests import Session
from signalr import Connection
from requests.exceptions import ConnectionError
from signalrcore.hub_connection_builder import HubConnectionBuilder
from config import settings, url_sonarr, url_radarr
from get_episodes import sync_episodes, sync_one_episode
from get_series import update_series, update_one_series
from get_movies import update_movies, update_one_movie
from scheduler import scheduler
from utils import get_sonarr_version
class SonarrSignalrClient(threading.Thread):
def __init__(self):
super(SonarrSignalrClient, self).__init__()
self.stopped = True
self.apikey_sonarr = None
self.session = Session()
self.connection = None
def stop(self):
self.connection.close()
self.stopped = True
logging.info('BAZARR SignalR client for Sonarr is now disconnected.')
def restart(self):
if not self.stopped:
self.stop()
if settings.general.getboolean('use_sonarr'):
self.run()
def run(self):
if get_sonarr_version().startswith('2.'):
logging.warning('BAZARR can only sync from Sonarr v3 SignalR feed to get real-time update. You should '
'consider upgrading.')
return
self.apikey_sonarr = settings.sonarr.apikey
self.connection = Connection(url_sonarr() + "/signalr", self.session)
self.connection.qs = {'apikey': self.apikey_sonarr}
sonarr_hub = self.connection.register_hub('') # Sonarr doesn't use named hub
sonarr_method = ['series', 'episode']
for item in sonarr_method:
sonarr_hub.client.on(item, dispatcher)
while True:
if not self.stopped:
return
if self.connection.started:
gevent.sleep(5)
else:
try:
logging.debug('BAZARR connecting to Sonarr SignalR feed...')
self.connection.start()
except ConnectionError:
logging.error('BAZARR connection to Sonarr SignalR feed has been lost. Reconnecting...')
gevent.sleep(15)
else:
self.stopped = False
logging.info('BAZARR SignalR client for Sonarr is connected and waiting for events.')
scheduler.execute_job_now('update_series')
scheduler.execute_job_now('sync_episodes')
gevent.sleep()
class RadarrSignalrClient(threading.Thread):
def __init__(self):
super(RadarrSignalrClient, self).__init__()
self.stopped = True
self.apikey_radarr = None
self.connection = None
def stop(self):
self.connection.stop()
self.stopped = True
logging.info('BAZARR SignalR client for Radarr is now disconnected.')
def restart(self):
if not self.stopped:
self.stop()
if settings.general.getboolean('use_radarr'):
self.run()
def run(self):
self.apikey_radarr = settings.radarr.apikey
self.connection = HubConnectionBuilder() \
.with_url(url_radarr() + "/signalr/messages?access_token={}".format(self.apikey_radarr),
options={
"verify_ssl": False
}).build()
self.connection.on_open(lambda: logging.debug("BAZARR SignalR client for Radarr is connected."))
self.connection.on_close(lambda: logging.debug("BAZARR SignalR client for Radarr is disconnected."))
self.connection.on_error(lambda data: logging.debug(f"BAZARR SignalR client for Radarr: An exception was thrown"
f" closed{data.error}"))
self.connection.on("receiveMessage", dispatcher)
while True:
if not self.stopped:
return
if self.connection.transport.state.value == 4:
# 0: 'connecting', 1: 'connected', 2: 'reconnecting', 4: 'disconnected'
try:
logging.debug('BAZARR connecting to Radarr SignalR feed...')
self.connection.start()
except ConnectionError:
logging.error('BAZARR connection to Radarr SignalR feed has been lost. Reconnecting...')
gevent.sleep(15)
pass
else:
self.stopped = False
logging.info('BAZARR SignalR client for Radarr is connected and waiting for events.')
scheduler.execute_job_now('update_movies')
gevent.sleep()
else:
gevent.sleep(5)
def dispatcher(data):
topic = media_id = action = None
if isinstance(data, dict):
topic = data['name']
try:
media_id = data['body']['resource']['id']
action = data['body']['action']
except KeyError:
return
elif isinstance(data, list):
topic = data[0]['name']
try:
media_id = data[0]['body']['resource']['id']
action = data[0]['body']['action']
except KeyError:
return
if topic == 'series':
update_one_series(series_id=media_id, action=action)
elif topic == 'episode':
sync_one_episode(episode_id=media_id)
elif topic == 'movie':
update_one_movie(movie_id=media_id, action=action)
else:
return
sonarr_signalr_client = SonarrSignalrClient()
radarr_signalr_client = RadarrSignalrClient()

@ -101,7 +101,7 @@ const Table: FunctionComponent<Props> = ({ episodes, profile }) => {
{
Header: "Subtitles",
accessor: "missing_subtitles",
Cell: ({ row }) => {
Cell: ({ row, loose }) => {
const episode = row.original;
const seriesid = episode.sonarrSeriesId;

@ -1,20 +1,15 @@
export const seriesSyncOptions: SelectorOption<number>[] = [
{ label: "1 Minute", value: 1 },
{ label: "5 Minutes", value: 5 },
{ label: "15 Minutes", value: 15 },
{ label: "1 Hour", value: 60 },
{ label: "3 Hours", value: 180 },
];
export const episodesSyncOptions: SelectorOption<number>[] = [
{ label: "5 Minutes", value: 5 },
{ label: "15 Minutes", value: 15 },
{ label: "1 Hour", value: 60 },
{ label: "3 Hours", value: 180 },
{ label: "6 Hours", value: 360 },
{ label: "12 Hours", value: 720 },
{ label: "24 Hours", value: 1440 },
];
export const moviesSyncOptions = episodesSyncOptions;
export const episodesSyncOptions = seriesSyncOptions;
export const moviesSyncOptions = seriesSyncOptions;
export const diskUpdateOptions: SelectorOption<string>[] = [
{ label: "Manually", value: "Manually" },

@ -0,0 +1,54 @@
# coding: utf-8
from ._version import version
from .exceptions import *
from .ext import ExtType, Timestamp
import os
import sys
if os.environ.get("MSGPACK_PUREPYTHON") or sys.version_info[0] == 2:
from .fallback import Packer, unpackb, Unpacker
else:
try:
from ._cmsgpack import Packer, unpackb, Unpacker
except ImportError:
from .fallback import Packer, unpackb, Unpacker
def pack(o, stream, **kwargs):
"""
Pack object `o` and write it to `stream`
See :class:`Packer` for options.
"""
packer = Packer(**kwargs)
stream.write(packer.pack(o))
def packb(o, **kwargs):
"""
Pack object `o` and return packed bytes
See :class:`Packer` for options.
"""
return Packer(**kwargs).pack(o)
def unpack(stream, **kwargs):
"""
Unpack an object from `stream`.
Raises `ExtraData` when `stream` contains extra bytes.
See :class:`Unpacker` for options.
"""
data = stream.read()
return unpackb(data, **kwargs)
# alias for compatibility to simplejson/marshal/pickle.
load = unpack
loads = unpackb
dump = pack
dumps = packb

@ -0,0 +1,11 @@
# coding: utf-8
#cython: embedsignature=True, c_string_encoding=ascii, language_level=3
from cpython.datetime cimport import_datetime, datetime_new
import_datetime()
import datetime
cdef object utc = datetime.timezone.utc
cdef object epoch = datetime_new(1970, 1, 1, 0, 0, 0, 0, tz=utc)
include "_packer.pyx"
include "_unpacker.pyx"

@ -0,0 +1,372 @@
# coding: utf-8
from cpython cimport *
from cpython.bytearray cimport PyByteArray_Check, PyByteArray_CheckExact
from cpython.datetime cimport (
PyDateTime_CheckExact, PyDelta_CheckExact,
datetime_tzinfo, timedelta_days, timedelta_seconds, timedelta_microseconds,
)
cdef ExtType
cdef Timestamp
from .ext import ExtType, Timestamp
cdef extern from "Python.h":
int PyMemoryView_Check(object obj)
char* PyUnicode_AsUTF8AndSize(object obj, Py_ssize_t *l) except NULL
cdef extern from "pack.h":
struct msgpack_packer:
char* buf
size_t length
size_t buf_size
bint use_bin_type
int msgpack_pack_int(msgpack_packer* pk, int d)
int msgpack_pack_nil(msgpack_packer* pk)
int msgpack_pack_true(msgpack_packer* pk)
int msgpack_pack_false(msgpack_packer* pk)
int msgpack_pack_long(msgpack_packer* pk, long d)
int msgpack_pack_long_long(msgpack_packer* pk, long long d)
int msgpack_pack_unsigned_long_long(msgpack_packer* pk, unsigned long long d)
int msgpack_pack_float(msgpack_packer* pk, float d)
int msgpack_pack_double(msgpack_packer* pk, double d)
int msgpack_pack_array(msgpack_packer* pk, size_t l)
int msgpack_pack_map(msgpack_packer* pk, size_t l)
int msgpack_pack_raw(msgpack_packer* pk, size_t l)
int msgpack_pack_bin(msgpack_packer* pk, size_t l)
int msgpack_pack_raw_body(msgpack_packer* pk, char* body, size_t l)
int msgpack_pack_ext(msgpack_packer* pk, char typecode, size_t l)
int msgpack_pack_timestamp(msgpack_packer* x, long long seconds, unsigned long nanoseconds);
int msgpack_pack_unicode(msgpack_packer* pk, object o, long long limit)
cdef extern from "buff_converter.h":
object buff_to_buff(char *, Py_ssize_t)
cdef int DEFAULT_RECURSE_LIMIT=511
cdef long long ITEM_LIMIT = (2**32)-1
cdef inline int PyBytesLike_Check(object o):
return PyBytes_Check(o) or PyByteArray_Check(o)
cdef inline int PyBytesLike_CheckExact(object o):
return PyBytes_CheckExact(o) or PyByteArray_CheckExact(o)
cdef class Packer(object):
"""
MessagePack Packer
Usage::
packer = Packer()
astream.write(packer.pack(a))
astream.write(packer.pack(b))
Packer's constructor has some keyword arguments:
:param callable default:
Convert user type to builtin type that Packer supports.
See also simplejson's document.
:param bool use_single_float:
Use single precision float type for float. (default: False)
:param bool autoreset:
Reset buffer after each pack and return its content as `bytes`. (default: True).
If set this to false, use `bytes()` to get content and `.reset()` to clear buffer.
:param bool use_bin_type:
Use bin type introduced in msgpack spec 2.0 for bytes.
It also enables str8 type for unicode. (default: True)
:param bool strict_types:
If set to true, types will be checked to be exact. Derived classes
from serializeable types will not be serialized and will be
treated as unsupported type and forwarded to default.
Additionally tuples will not be serialized as lists.
This is useful when trying to implement accurate serialization
for python types.
:param bool datetime:
If set to true, datetime with tzinfo is packed into Timestamp type.
Note that the tzinfo is stripped in the timestamp.
You can get UTC datetime with `timestamp=3` option of the Unpacker.
(Python 2 is not supported).
:param str unicode_errors:
The error handler for encoding unicode. (default: 'strict')
DO NOT USE THIS!! This option is kept for very specific usage.
"""
cdef msgpack_packer pk
cdef object _default
cdef object _berrors
cdef const char *unicode_errors
cdef bint strict_types
cdef bint use_float
cdef bint autoreset
cdef bint datetime
def __cinit__(self):
cdef int buf_size = 1024*1024
self.pk.buf = <char*> PyMem_Malloc(buf_size)
if self.pk.buf == NULL:
raise MemoryError("Unable to allocate internal buffer.")
self.pk.buf_size = buf_size
self.pk.length = 0
def __init__(self, *, default=None,
bint use_single_float=False, bint autoreset=True, bint use_bin_type=True,
bint strict_types=False, bint datetime=False, unicode_errors=None):
self.use_float = use_single_float
self.strict_types = strict_types
self.autoreset = autoreset
self.datetime = datetime
self.pk.use_bin_type = use_bin_type
if default is not None:
if not PyCallable_Check(default):
raise TypeError("default must be a callable.")
self._default = default
self._berrors = unicode_errors
if unicode_errors is None:
self.unicode_errors = NULL
else:
self.unicode_errors = self._berrors
def __dealloc__(self):
PyMem_Free(self.pk.buf)
self.pk.buf = NULL
cdef int _pack(self, object o, int nest_limit=DEFAULT_RECURSE_LIMIT) except -1:
cdef long long llval
cdef unsigned long long ullval
cdef unsigned long ulval
cdef long longval
cdef float fval
cdef double dval
cdef char* rawval
cdef int ret
cdef dict d
cdef Py_ssize_t L
cdef int default_used = 0
cdef bint strict_types = self.strict_types
cdef Py_buffer view
if nest_limit < 0:
raise ValueError("recursion limit exceeded.")
while True:
if o is None:
ret = msgpack_pack_nil(&self.pk)
elif o is True:
ret = msgpack_pack_true(&self.pk)
elif o is False:
ret = msgpack_pack_false(&self.pk)
elif PyLong_CheckExact(o) if strict_types else PyLong_Check(o):
# PyInt_Check(long) is True for Python 3.
# So we should test long before int.
try:
if o > 0:
ullval = o
ret = msgpack_pack_unsigned_long_long(&self.pk, ullval)
else:
llval = o
ret = msgpack_pack_long_long(&self.pk, llval)
except OverflowError as oe:
if not default_used and self._default is not None:
o = self._default(o)
default_used = True
continue
else:
raise OverflowError("Integer value out of range")
elif PyInt_CheckExact(o) if strict_types else PyInt_Check(o):
longval = o
ret = msgpack_pack_long(&self.pk, longval)
elif PyFloat_CheckExact(o) if strict_types else PyFloat_Check(o):
if self.use_float:
fval = o
ret = msgpack_pack_float(&self.pk, fval)
else:
dval = o
ret = msgpack_pack_double(&self.pk, dval)
elif PyBytesLike_CheckExact(o) if strict_types else PyBytesLike_Check(o):
L = Py_SIZE(o)
if L > ITEM_LIMIT:
PyErr_Format(ValueError, b"%.200s object is too large", Py_TYPE(o).tp_name)
rawval = o
ret = msgpack_pack_bin(&self.pk, L)
if ret == 0:
ret = msgpack_pack_raw_body(&self.pk, rawval, L)
elif PyUnicode_CheckExact(o) if strict_types else PyUnicode_Check(o):
if self.unicode_errors == NULL:
ret = msgpack_pack_unicode(&self.pk, o, ITEM_LIMIT);
if ret == -2:
raise ValueError("unicode string is too large")
else:
o = PyUnicode_AsEncodedString(o, NULL, self.unicode_errors)
L = Py_SIZE(o)
if L > ITEM_LIMIT:
raise ValueError("unicode string is too large")
ret = msgpack_pack_raw(&self.pk, L)
if ret == 0:
rawval = o
ret = msgpack_pack_raw_body(&self.pk, rawval, L)
elif PyDict_CheckExact(o):
d = <dict>o
L = len(d)
if L > ITEM_LIMIT:
raise ValueError("dict is too large")
ret = msgpack_pack_map(&self.pk, L)
if ret == 0:
for k, v in d.items():
ret = self._pack(k, nest_limit-1)
if ret != 0: break
ret = self._pack(v, nest_limit-1)
if ret != 0: break
elif not strict_types and PyDict_Check(o):
L = len(o)
if L > ITEM_LIMIT:
raise ValueError("dict is too large")
ret = msgpack_pack_map(&self.pk, L)
if ret == 0:
for k, v in o.items():
ret = self._pack(k, nest_limit-1)
if ret != 0: break
ret = self._pack(v, nest_limit-1)
if ret != 0: break
elif type(o) is ExtType if strict_types else isinstance(o, ExtType):
# This should be before Tuple because ExtType is namedtuple.
longval = o.code
rawval = o.data
L = len(o.data)
if L > ITEM_LIMIT:
raise ValueError("EXT data is too large")
ret = msgpack_pack_ext(&self.pk, longval, L)
ret = msgpack_pack_raw_body(&self.pk, rawval, L)
elif type(o) is Timestamp:
llval = o.seconds
ulval = o.nanoseconds
ret = msgpack_pack_timestamp(&self.pk, llval, ulval)
elif PyList_CheckExact(o) if strict_types else (PyTuple_Check(o) or PyList_Check(o)):
L = Py_SIZE(o)
if L > ITEM_LIMIT:
raise ValueError("list is too large")
ret = msgpack_pack_array(&self.pk, L)
if ret == 0:
for v in o:
ret = self._pack(v, nest_limit-1)
if ret != 0: break
elif PyMemoryView_Check(o):
if PyObject_GetBuffer(o, &view, PyBUF_SIMPLE) != 0:
raise ValueError("could not get buffer for memoryview")
L = view.len
if L > ITEM_LIMIT:
PyBuffer_Release(&view);
raise ValueError("memoryview is too large")
ret = msgpack_pack_bin(&self.pk, L)
if ret == 0:
ret = msgpack_pack_raw_body(&self.pk, <char*>view.buf, L)
PyBuffer_Release(&view);
elif self.datetime and PyDateTime_CheckExact(o) and datetime_tzinfo(o) is not None:
delta = o - epoch
if not PyDelta_CheckExact(delta):
raise ValueError("failed to calculate delta")
llval = timedelta_days(delta) * <long long>(24*60*60) + timedelta_seconds(delta)
ulval = timedelta_microseconds(delta) * 1000
ret = msgpack_pack_timestamp(&self.pk, llval, ulval)
elif not default_used and self._default:
o = self._default(o)
default_used = 1
continue
else:
PyErr_Format(TypeError, b"can not serialize '%.200s' object", Py_TYPE(o).tp_name)
return ret
cpdef pack(self, object obj):
cdef int ret
try:
ret = self._pack(obj, DEFAULT_RECURSE_LIMIT)
except:
self.pk.length = 0
raise
if ret: # should not happen.
raise RuntimeError("internal error")
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def pack_ext_type(self, typecode, data):
msgpack_pack_ext(&self.pk, typecode, len(data))
msgpack_pack_raw_body(&self.pk, data, len(data))
def pack_array_header(self, long long size):
if size > ITEM_LIMIT:
raise ValueError
cdef int ret = msgpack_pack_array(&self.pk, size)
if ret == -1:
raise MemoryError
elif ret: # should not happen
raise TypeError
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def pack_map_header(self, long long size):
if size > ITEM_LIMIT:
raise ValueError
cdef int ret = msgpack_pack_map(&self.pk, size)
if ret == -1:
raise MemoryError
elif ret: # should not happen
raise TypeError
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def pack_map_pairs(self, object pairs):
"""
Pack *pairs* as msgpack map type.
*pairs* should be a sequence of pairs.
(`len(pairs)` and `for k, v in pairs:` should be supported.)
"""
cdef int ret = msgpack_pack_map(&self.pk, len(pairs))
if ret == 0:
for k, v in pairs:
ret = self._pack(k)
if ret != 0: break
ret = self._pack(v)
if ret != 0: break
if ret == -1:
raise MemoryError
elif ret: # should not happen
raise TypeError
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def reset(self):
"""Reset internal buffer.
This method is useful only when autoreset=False.
"""
self.pk.length = 0
def bytes(self):
"""Return internal buffer contents as bytes object"""
return PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
def getbuffer(self):
"""Return view of internal buffer."""
return buff_to_buff(self.pk.buf, self.pk.length)

@ -0,0 +1,551 @@
# coding: utf-8
from cpython cimport *
cdef extern from "Python.h":
ctypedef struct PyObject
object PyMemoryView_GetContiguous(object obj, int buffertype, char order)
from libc.stdlib cimport *
from libc.string cimport *
from libc.limits cimport *
from libc.stdint cimport uint64_t
from .exceptions import (
BufferFull,
OutOfData,
ExtraData,
FormatError,
StackError,
)
from .ext import ExtType, Timestamp
cdef object giga = 1_000_000_000
cdef extern from "unpack.h":
ctypedef struct msgpack_user:
bint use_list
bint raw
bint has_pairs_hook # call object_hook with k-v pairs
bint strict_map_key
int timestamp
PyObject* object_hook
PyObject* list_hook
PyObject* ext_hook
PyObject* timestamp_t
PyObject *giga;
PyObject *utc;
char *unicode_errors
Py_ssize_t max_str_len
Py_ssize_t max_bin_len
Py_ssize_t max_array_len
Py_ssize_t max_map_len
Py_ssize_t max_ext_len
ctypedef struct unpack_context:
msgpack_user user
PyObject* obj
Py_ssize_t count
ctypedef int (*execute_fn)(unpack_context* ctx, const char* data,
Py_ssize_t len, Py_ssize_t* off) except? -1
execute_fn unpack_construct
execute_fn unpack_skip
execute_fn read_array_header
execute_fn read_map_header
void unpack_init(unpack_context* ctx)
object unpack_data(unpack_context* ctx)
void unpack_clear(unpack_context* ctx)
cdef inline init_ctx(unpack_context *ctx,
object object_hook, object object_pairs_hook,
object list_hook, object ext_hook,
bint use_list, bint raw, int timestamp,
bint strict_map_key,
const char* unicode_errors,
Py_ssize_t max_str_len, Py_ssize_t max_bin_len,
Py_ssize_t max_array_len, Py_ssize_t max_map_len,
Py_ssize_t max_ext_len):
unpack_init(ctx)
ctx.user.use_list = use_list
ctx.user.raw = raw
ctx.user.strict_map_key = strict_map_key
ctx.user.object_hook = ctx.user.list_hook = <PyObject*>NULL
ctx.user.max_str_len = max_str_len
ctx.user.max_bin_len = max_bin_len
ctx.user.max_array_len = max_array_len
ctx.user.max_map_len = max_map_len
ctx.user.max_ext_len = max_ext_len
if object_hook is not None and object_pairs_hook is not None:
raise TypeError("object_pairs_hook and object_hook are mutually exclusive.")
if object_hook is not None:
if not PyCallable_Check(object_hook):
raise TypeError("object_hook must be a callable.")
ctx.user.object_hook = <PyObject*>object_hook
if object_pairs_hook is None:
ctx.user.has_pairs_hook = False
else:
if not PyCallable_Check(object_pairs_hook):
raise TypeError("object_pairs_hook must be a callable.")
ctx.user.object_hook = <PyObject*>object_pairs_hook
ctx.user.has_pairs_hook = True
if list_hook is not None:
if not PyCallable_Check(list_hook):
raise TypeError("list_hook must be a callable.")
ctx.user.list_hook = <PyObject*>list_hook
if ext_hook is not None:
if not PyCallable_Check(ext_hook):
raise TypeError("ext_hook must be a callable.")
ctx.user.ext_hook = <PyObject*>ext_hook
if timestamp < 0 or 3 < timestamp:
raise ValueError("timestamp must be 0..3")
# Add Timestamp type to the user object so it may be used in unpack.h
ctx.user.timestamp = timestamp
ctx.user.timestamp_t = <PyObject*>Timestamp
ctx.user.giga = <PyObject*>giga
ctx.user.utc = <PyObject*>utc
ctx.user.unicode_errors = unicode_errors
def default_read_extended_type(typecode, data):
raise NotImplementedError("Cannot decode extended type with typecode=%d" % typecode)
cdef inline int get_data_from_buffer(object obj,
Py_buffer *view,
char **buf,
Py_ssize_t *buffer_len) except 0:
cdef object contiguous
cdef Py_buffer tmp
if PyObject_GetBuffer(obj, view, PyBUF_FULL_RO) == -1:
raise
if view.itemsize != 1:
PyBuffer_Release(view)
raise BufferError("cannot unpack from multi-byte object")
if PyBuffer_IsContiguous(view, b'A') == 0:
PyBuffer_Release(view)
# create a contiguous copy and get buffer
contiguous = PyMemoryView_GetContiguous(obj, PyBUF_READ, b'C')
PyObject_GetBuffer(contiguous, view, PyBUF_SIMPLE)
# view must hold the only reference to contiguous,
# so memory is freed when view is released
Py_DECREF(contiguous)
buffer_len[0] = view.len
buf[0] = <char*> view.buf
return 1
def unpackb(object packed, *, object object_hook=None, object list_hook=None,
bint use_list=True, bint raw=False, int timestamp=0, bint strict_map_key=True,
unicode_errors=None,
object_pairs_hook=None, ext_hook=ExtType,
Py_ssize_t max_str_len=-1,
Py_ssize_t max_bin_len=-1,
Py_ssize_t max_array_len=-1,
Py_ssize_t max_map_len=-1,
Py_ssize_t max_ext_len=-1):
"""
Unpack packed_bytes to object. Returns an unpacked object.
Raises ``ExtraData`` when *packed* contains extra bytes.
Raises ``ValueError`` when *packed* is incomplete.
Raises ``FormatError`` when *packed* is not valid msgpack.
Raises ``StackError`` when *packed* contains too nested.
Other exceptions can be raised during unpacking.
See :class:`Unpacker` for options.
*max_xxx_len* options are configured automatically from ``len(packed)``.
"""
cdef unpack_context ctx
cdef Py_ssize_t off = 0
cdef int ret
cdef Py_buffer view
cdef char* buf = NULL
cdef Py_ssize_t buf_len
cdef const char* cerr = NULL
if unicode_errors is not None:
cerr = unicode_errors
get_data_from_buffer(packed, &view, &buf, &buf_len)
if max_str_len == -1:
max_str_len = buf_len
if max_bin_len == -1:
max_bin_len = buf_len
if max_array_len == -1:
max_array_len = buf_len
if max_map_len == -1:
max_map_len = buf_len//2
if max_ext_len == -1:
max_ext_len = buf_len
try:
init_ctx(&ctx, object_hook, object_pairs_hook, list_hook, ext_hook,
use_list, raw, timestamp, strict_map_key, cerr,
max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len)
ret = unpack_construct(&ctx, buf, buf_len, &off)
finally:
PyBuffer_Release(&view);
if ret == 1:
obj = unpack_data(&ctx)
if off < buf_len:
raise ExtraData(obj, PyBytes_FromStringAndSize(buf+off, buf_len-off))
return obj
unpack_clear(&ctx)
if ret == 0:
raise ValueError("Unpack failed: incomplete input")
elif ret == -2:
raise FormatError
elif ret == -3:
raise StackError
raise ValueError("Unpack failed: error = %d" % (ret,))
cdef class Unpacker(object):
"""Streaming unpacker.
Arguments:
:param file_like:
File-like object having `.read(n)` method.
If specified, unpacker reads serialized data from it and :meth:`feed()` is not usable.
:param int read_size:
Used as `file_like.read(read_size)`. (default: `min(16*1024, max_buffer_size)`)
:param bool use_list:
If true, unpack msgpack array to Python list.
Otherwise, unpack to Python tuple. (default: True)
:param bool raw:
If true, unpack msgpack raw to Python bytes.
Otherwise, unpack to Python str by decoding with UTF-8 encoding (default).
:param int timestamp:
Control how timestamp type is unpacked:
0 - Timestamp
1 - float (Seconds from the EPOCH)
2 - int (Nanoseconds from the EPOCH)
3 - datetime.datetime (UTC). Python 2 is not supported.
:param bool strict_map_key:
If true (default), only str or bytes are accepted for map (dict) keys.
:param callable object_hook:
When specified, it should be callable.
Unpacker calls it with a dict argument after unpacking msgpack map.
(See also simplejson)
:param callable object_pairs_hook:
When specified, it should be callable.
Unpacker calls it with a list of key-value pairs after unpacking msgpack map.
(See also simplejson)
:param str unicode_errors:
The error handler for decoding unicode. (default: 'strict')
This option should be used only when you have msgpack data which
contains invalid UTF-8 string.
:param int max_buffer_size:
Limits size of data waiting unpacked. 0 means 2**32-1.
The default value is 100*1024*1024 (100MiB).
Raises `BufferFull` exception when it is insufficient.
You should set this parameter when unpacking data from untrusted source.
:param int max_str_len:
Deprecated, use *max_buffer_size* instead.
Limits max length of str. (default: max_buffer_size)
:param int max_bin_len:
Deprecated, use *max_buffer_size* instead.
Limits max length of bin. (default: max_buffer_size)
:param int max_array_len:
Limits max length of array.
(default: max_buffer_size)
:param int max_map_len:
Limits max length of map.
(default: max_buffer_size//2)
:param int max_ext_len:
Deprecated, use *max_buffer_size* instead.
Limits max size of ext type. (default: max_buffer_size)
Example of streaming deserialize from file-like object::
unpacker = Unpacker(file_like)
for o in unpacker:
process(o)
Example of streaming deserialize from socket::
unpacker = Unpacker()
while True:
buf = sock.recv(1024**2)
if not buf:
break
unpacker.feed(buf)
for o in unpacker:
process(o)
Raises ``ExtraData`` when *packed* contains extra bytes.
Raises ``OutOfData`` when *packed* is incomplete.
Raises ``FormatError`` when *packed* is not valid msgpack.
Raises ``StackError`` when *packed* contains too nested.
Other exceptions can be raised during unpacking.
"""
cdef unpack_context ctx
cdef char* buf
cdef Py_ssize_t buf_size, buf_head, buf_tail
cdef object file_like
cdef object file_like_read
cdef Py_ssize_t read_size
# To maintain refcnt.
cdef object object_hook, object_pairs_hook, list_hook, ext_hook
cdef object unicode_errors
cdef Py_ssize_t max_buffer_size
cdef uint64_t stream_offset
def __cinit__(self):
self.buf = NULL
def __dealloc__(self):
PyMem_Free(self.buf)
self.buf = NULL
def __init__(self, file_like=None, *, Py_ssize_t read_size=0,
bint use_list=True, bint raw=False, int timestamp=0, bint strict_map_key=True,
object object_hook=None, object object_pairs_hook=None, object list_hook=None,
unicode_errors=None, Py_ssize_t max_buffer_size=100*1024*1024,
object ext_hook=ExtType,
Py_ssize_t max_str_len=-1,
Py_ssize_t max_bin_len=-1,
Py_ssize_t max_array_len=-1,
Py_ssize_t max_map_len=-1,
Py_ssize_t max_ext_len=-1):
cdef const char *cerr=NULL
self.object_hook = object_hook
self.object_pairs_hook = object_pairs_hook
self.list_hook = list_hook
self.ext_hook = ext_hook
self.file_like = file_like
if file_like:
self.file_like_read = file_like.read
if not PyCallable_Check(self.file_like_read):
raise TypeError("`file_like.read` must be a callable.")
if not max_buffer_size:
max_buffer_size = INT_MAX
if max_str_len == -1:
max_str_len = max_buffer_size
if max_bin_len == -1:
max_bin_len = max_buffer_size
if max_array_len == -1:
max_array_len = max_buffer_size
if max_map_len == -1:
max_map_len = max_buffer_size//2
if max_ext_len == -1:
max_ext_len = max_buffer_size
if read_size > max_buffer_size:
raise ValueError("read_size should be less or equal to max_buffer_size")
if not read_size:
read_size = min(max_buffer_size, 1024**2)
self.max_buffer_size = max_buffer_size
self.read_size = read_size
self.buf = <char*>PyMem_Malloc(read_size)
if self.buf == NULL:
raise MemoryError("Unable to allocate internal buffer.")
self.buf_size = read_size
self.buf_head = 0
self.buf_tail = 0
self.stream_offset = 0
if unicode_errors is not None:
self.unicode_errors = unicode_errors
cerr = unicode_errors
init_ctx(&self.ctx, object_hook, object_pairs_hook, list_hook,
ext_hook, use_list, raw, timestamp, strict_map_key, cerr,
max_str_len, max_bin_len, max_array_len,
max_map_len, max_ext_len)
def feed(self, object next_bytes):
"""Append `next_bytes` to internal buffer."""
cdef Py_buffer pybuff
cdef char* buf
cdef Py_ssize_t buf_len
if self.file_like is not None:
raise AssertionError(
"unpacker.feed() is not be able to use with `file_like`.")
get_data_from_buffer(next_bytes, &pybuff, &buf, &buf_len)
try:
self.append_buffer(buf, buf_len)
finally:
PyBuffer_Release(&pybuff)
cdef append_buffer(self, void* _buf, Py_ssize_t _buf_len):
cdef:
char* buf = self.buf
char* new_buf
Py_ssize_t head = self.buf_head
Py_ssize_t tail = self.buf_tail
Py_ssize_t buf_size = self.buf_size
Py_ssize_t new_size
if tail + _buf_len > buf_size:
if ((tail - head) + _buf_len) <= buf_size:
# move to front.
memmove(buf, buf + head, tail - head)
tail -= head
head = 0
else:
# expand buffer.
new_size = (tail-head) + _buf_len
if new_size > self.max_buffer_size:
raise BufferFull
new_size = min(new_size*2, self.max_buffer_size)
new_buf = <char*>PyMem_Malloc(new_size)
if new_buf == NULL:
# self.buf still holds old buffer and will be freed during
# obj destruction
raise MemoryError("Unable to enlarge internal buffer.")
memcpy(new_buf, buf + head, tail - head)
PyMem_Free(buf)
buf = new_buf
buf_size = new_size
tail -= head
head = 0
memcpy(buf + tail, <char*>(_buf), _buf_len)
self.buf = buf
self.buf_head = head
self.buf_size = buf_size
self.buf_tail = tail + _buf_len
cdef read_from_file(self):
next_bytes = self.file_like_read(
min(self.read_size,
self.max_buffer_size - (self.buf_tail - self.buf_head)
))
if next_bytes:
self.append_buffer(PyBytes_AsString(next_bytes), PyBytes_Size(next_bytes))
else:
self.file_like = None
cdef object _unpack(self, execute_fn execute, bint iter=0):
cdef int ret
cdef object obj
cdef Py_ssize_t prev_head
if self.buf_head >= self.buf_tail and self.file_like is not None:
self.read_from_file()
while 1:
prev_head = self.buf_head
if prev_head >= self.buf_tail:
if iter:
raise StopIteration("No more data to unpack.")
else:
raise OutOfData("No more data to unpack.")
ret = execute(&self.ctx, self.buf, self.buf_tail, &self.buf_head)
self.stream_offset += self.buf_head - prev_head
if ret == 1:
obj = unpack_data(&self.ctx)
unpack_init(&self.ctx)
return obj
elif ret == 0:
if self.file_like is not None:
self.read_from_file()
continue
if iter:
raise StopIteration("No more data to unpack.")
else:
raise OutOfData("No more data to unpack.")
elif ret == -2:
raise FormatError
elif ret == -3:
raise StackError
else:
raise ValueError("Unpack failed: error = %d" % (ret,))
def read_bytes(self, Py_ssize_t nbytes):
"""Read a specified number of raw bytes from the stream"""
cdef Py_ssize_t nread
nread = min(self.buf_tail - self.buf_head, nbytes)
ret = PyBytes_FromStringAndSize(self.buf + self.buf_head, nread)
self.buf_head += nread
if nread < nbytes and self.file_like is not None:
ret += self.file_like.read(nbytes - nread)
nread = len(ret)
self.stream_offset += nread
return ret
def unpack(self):
"""Unpack one object
Raises `OutOfData` when there are no more bytes to unpack.
"""
return self._unpack(unpack_construct)
def skip(self):
"""Read and ignore one object, returning None
Raises `OutOfData` when there are no more bytes to unpack.
"""
return self._unpack(unpack_skip)
def read_array_header(self):
"""assuming the next object is an array, return its size n, such that
the next n unpack() calls will iterate over its contents.
Raises `OutOfData` when there are no more bytes to unpack.
"""
return self._unpack(read_array_header)
def read_map_header(self):
"""assuming the next object is a map, return its size n, such that the
next n * 2 unpack() calls will iterate over its key-value pairs.
Raises `OutOfData` when there are no more bytes to unpack.
"""
return self._unpack(read_map_header)
def tell(self):
"""Returns the current position of the Unpacker in bytes, i.e., the
number of bytes that were read from the input, also the starting
position of the next object.
"""
return self.stream_offset
def __iter__(self):
return self
def __next__(self):
return self._unpack(unpack_construct, 1)
# for debug.
#def _buf(self):
# return PyString_FromStringAndSize(self.buf, self.buf_tail)
#def _off(self):
# return self.buf_head

@ -0,0 +1 @@
version = (1, 0, 2)

@ -0,0 +1,8 @@
#include "Python.h"
/* cython does not support this preprocessor check => write it in raw C */
static PyObject *
buff_to_buff(char *buff, Py_ssize_t size)
{
return PyMemoryView_FromMemory(buff, size, PyBUF_READ);
}

@ -0,0 +1,48 @@
class UnpackException(Exception):
"""Base class for some exceptions raised while unpacking.
NOTE: unpack may raise exception other than subclass of
UnpackException. If you want to catch all error, catch
Exception instead.
"""
class BufferFull(UnpackException):
pass
class OutOfData(UnpackException):
pass
class FormatError(ValueError, UnpackException):
"""Invalid msgpack format"""
class StackError(ValueError, UnpackException):
"""Too nested"""
# Deprecated. Use ValueError instead
UnpackValueError = ValueError
class ExtraData(UnpackValueError):
"""ExtraData is raised when there is trailing data.
This exception is raised while only one-shot (not streaming)
unpack.
"""
def __init__(self, unpacked, extra):
self.unpacked = unpacked
self.extra = extra
def __str__(self):
return "unpack(b) received extra data."
# Deprecated. Use Exception instead to catch all exception during packing.
PackException = Exception
PackValueError = ValueError
PackOverflowError = OverflowError

@ -0,0 +1,193 @@
# coding: utf-8
from collections import namedtuple
import datetime
import sys
import struct
PY2 = sys.version_info[0] == 2
if PY2:
int_types = (int, long)
_utc = None
else:
int_types = int
try:
_utc = datetime.timezone.utc
except AttributeError:
_utc = datetime.timezone(datetime.timedelta(0))
class ExtType(namedtuple("ExtType", "code data")):
"""ExtType represents ext type in msgpack."""
def __new__(cls, code, data):
if not isinstance(code, int):
raise TypeError("code must be int")
if not isinstance(data, bytes):
raise TypeError("data must be bytes")
if not 0 <= code <= 127:
raise ValueError("code must be 0~127")
return super(ExtType, cls).__new__(cls, code, data)
class Timestamp(object):
"""Timestamp represents the Timestamp extension type in msgpack.
When built with Cython, msgpack uses C methods to pack and unpack `Timestamp`. When using pure-Python
msgpack, :func:`to_bytes` and :func:`from_bytes` are used to pack and unpack `Timestamp`.
This class is immutable: Do not override seconds and nanoseconds.
"""
__slots__ = ["seconds", "nanoseconds"]
def __init__(self, seconds, nanoseconds=0):
"""Initialize a Timestamp object.
:param int seconds:
Number of seconds since the UNIX epoch (00:00:00 UTC Jan 1 1970, minus leap seconds).
May be negative.
:param int nanoseconds:
Number of nanoseconds to add to `seconds` to get fractional time.
Maximum is 999_999_999. Default is 0.
Note: Negative times (before the UNIX epoch) are represented as negative seconds + positive ns.
"""
if not isinstance(seconds, int_types):
raise TypeError("seconds must be an interger")
if not isinstance(nanoseconds, int_types):
raise TypeError("nanoseconds must be an integer")
if not (0 <= nanoseconds < 10 ** 9):
raise ValueError(
"nanoseconds must be a non-negative integer less than 999999999."
)
self.seconds = seconds
self.nanoseconds = nanoseconds
def __repr__(self):
"""String representation of Timestamp."""
return "Timestamp(seconds={0}, nanoseconds={1})".format(
self.seconds, self.nanoseconds
)
def __eq__(self, other):
"""Check for equality with another Timestamp object"""
if type(other) is self.__class__:
return (
self.seconds == other.seconds and self.nanoseconds == other.nanoseconds
)
return False
def __ne__(self, other):
"""not-equals method (see :func:`__eq__()`)"""
return not self.__eq__(other)
def __hash__(self):
return hash((self.seconds, self.nanoseconds))
@staticmethod
def from_bytes(b):
"""Unpack bytes into a `Timestamp` object.
Used for pure-Python msgpack unpacking.
:param b: Payload from msgpack ext message with code -1
:type b: bytes
:returns: Timestamp object unpacked from msgpack ext payload
:rtype: Timestamp
"""
if len(b) == 4:
seconds = struct.unpack("!L", b)[0]
nanoseconds = 0
elif len(b) == 8:
data64 = struct.unpack("!Q", b)[0]
seconds = data64 & 0x00000003FFFFFFFF
nanoseconds = data64 >> 34
elif len(b) == 12:
nanoseconds, seconds = struct.unpack("!Iq", b)
else:
raise ValueError(
"Timestamp type can only be created from 32, 64, or 96-bit byte objects"
)
return Timestamp(seconds, nanoseconds)
def to_bytes(self):
"""Pack this Timestamp object into bytes.
Used for pure-Python msgpack packing.
:returns data: Payload for EXT message with code -1 (timestamp type)
:rtype: bytes
"""
if (self.seconds >> 34) == 0: # seconds is non-negative and fits in 34 bits
data64 = self.nanoseconds << 34 | self.seconds
if data64 & 0xFFFFFFFF00000000 == 0:
# nanoseconds is zero and seconds < 2**32, so timestamp 32
data = struct.pack("!L", data64)
else:
# timestamp 64
data = struct.pack("!Q", data64)
else:
# timestamp 96
data = struct.pack("!Iq", self.nanoseconds, self.seconds)
return data
@staticmethod
def from_unix(unix_sec):
"""Create a Timestamp from posix timestamp in seconds.
:param unix_float: Posix timestamp in seconds.
:type unix_float: int or float.
"""
seconds = int(unix_sec // 1)
nanoseconds = int((unix_sec % 1) * 10 ** 9)
return Timestamp(seconds, nanoseconds)
def to_unix(self):
"""Get the timestamp as a floating-point value.
:returns: posix timestamp
:rtype: float
"""
return self.seconds + self.nanoseconds / 1e9
@staticmethod
def from_unix_nano(unix_ns):
"""Create a Timestamp from posix timestamp in nanoseconds.
:param int unix_ns: Posix timestamp in nanoseconds.
:rtype: Timestamp
"""
return Timestamp(*divmod(unix_ns, 10 ** 9))
def to_unix_nano(self):
"""Get the timestamp as a unixtime in nanoseconds.
:returns: posix timestamp in nanoseconds
:rtype: int
"""
return self.seconds * 10 ** 9 + self.nanoseconds
def to_datetime(self):
"""Get the timestamp as a UTC datetime.
Python 2 is not supported.
:rtype: datetime.
"""
return datetime.datetime.fromtimestamp(0, _utc) + datetime.timedelta(
seconds=self.to_unix()
)
@staticmethod
def from_datetime(dt):
"""Create a Timestamp from datetime with tzinfo.
Python 2 is not supported.
:rtype: Timestamp
"""
return Timestamp.from_unix(dt.timestamp())

File diff suppressed because it is too large Load Diff

@ -0,0 +1,119 @@
/*
* MessagePack for Python packing routine
*
* Copyright (C) 2009 Naoki INADA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stddef.h>
#include <stdlib.h>
#include "sysdep.h"
#include <limits.h>
#include <string.h>
#ifdef __cplusplus
extern "C" {
#endif
#ifdef _MSC_VER
#define inline __inline
#endif
typedef struct msgpack_packer {
char *buf;
size_t length;
size_t buf_size;
bool use_bin_type;
} msgpack_packer;
typedef struct Packer Packer;
static inline int msgpack_pack_write(msgpack_packer* pk, const char *data, size_t l)
{
char* buf = pk->buf;
size_t bs = pk->buf_size;
size_t len = pk->length;
if (len + l > bs) {
bs = (len + l) * 2;
buf = (char*)PyMem_Realloc(buf, bs);
if (!buf) {
PyErr_NoMemory();
return -1;
}
}
memcpy(buf + len, data, l);
len += l;
pk->buf = buf;
pk->buf_size = bs;
pk->length = len;
return 0;
}
#define msgpack_pack_append_buffer(user, buf, len) \
return msgpack_pack_write(user, (const char*)buf, len)
#include "pack_template.h"
// return -2 when o is too long
static inline int
msgpack_pack_unicode(msgpack_packer *pk, PyObject *o, long long limit)
{
#if PY_MAJOR_VERSION >= 3
assert(PyUnicode_Check(o));
Py_ssize_t len;
const char* buf = PyUnicode_AsUTF8AndSize(o, &len);
if (buf == NULL)
return -1;
if (len > limit) {
return -2;
}
int ret = msgpack_pack_raw(pk, len);
if (ret) return ret;
return msgpack_pack_raw_body(pk, buf, len);
#else
PyObject *bytes;
Py_ssize_t len;
int ret;
// py2
bytes = PyUnicode_AsUTF8String(o);
if (bytes == NULL)
return -1;
len = PyString_GET_SIZE(bytes);
if (len > limit) {
Py_DECREF(bytes);
return -2;
}
ret = msgpack_pack_raw(pk, len);
if (ret) {
Py_DECREF(bytes);
return -1;
}
ret = msgpack_pack_raw_body(pk, PyString_AS_STRING(bytes), len);
Py_DECREF(bytes);
return ret;
#endif
}
#ifdef __cplusplus
}
#endif

@ -0,0 +1,811 @@
/*
* MessagePack packing routine template
*
* Copyright (C) 2008-2010 FURUHASHI Sadayuki
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(__LITTLE_ENDIAN__)
#define TAKE8_8(d) ((uint8_t*)&d)[0]
#define TAKE8_16(d) ((uint8_t*)&d)[0]
#define TAKE8_32(d) ((uint8_t*)&d)[0]
#define TAKE8_64(d) ((uint8_t*)&d)[0]
#elif defined(__BIG_ENDIAN__)
#define TAKE8_8(d) ((uint8_t*)&d)[0]
#define TAKE8_16(d) ((uint8_t*)&d)[1]
#define TAKE8_32(d) ((uint8_t*)&d)[3]
#define TAKE8_64(d) ((uint8_t*)&d)[7]
#endif
#ifndef msgpack_pack_append_buffer
#error msgpack_pack_append_buffer callback is not defined
#endif
/*
* Integer
*/
#define msgpack_pack_real_uint8(x, d) \
do { \
if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_8(d), 1); \
} else { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_8(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} while(0)
#define msgpack_pack_real_uint16(x, d) \
do { \
if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_16(d), 1); \
} else if(d < (1<<8)) { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_16(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} \
} while(0)
#define msgpack_pack_real_uint32(x, d) \
do { \
if(d < (1<<8)) { \
if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_32(d), 1); \
} else { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_32(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} else { \
if(d < (1<<16)) { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* unsigned 32 */ \
unsigned char buf[5]; \
buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} \
} \
} while(0)
#define msgpack_pack_real_uint64(x, d) \
do { \
if(d < (1ULL<<8)) { \
if(d < (1ULL<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_64(d), 1); \
} else { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_64(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} else { \
if(d < (1ULL<<16)) { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else if(d < (1ULL<<32)) { \
/* unsigned 32 */ \
unsigned char buf[5]; \
buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} else { \
/* unsigned 64 */ \
unsigned char buf[9]; \
buf[0] = 0xcf; _msgpack_store64(&buf[1], d); \
msgpack_pack_append_buffer(x, buf, 9); \
} \
} \
} while(0)
#define msgpack_pack_real_int8(x, d) \
do { \
if(d < -(1<<5)) { \
/* signed 8 */ \
unsigned char buf[2] = {0xd0, TAKE8_8(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_8(d), 1); \
} \
} while(0)
#define msgpack_pack_real_int16(x, d) \
do { \
if(d < -(1<<5)) { \
if(d < -(1<<7)) { \
/* signed 16 */ \
unsigned char buf[3]; \
buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* signed 8 */ \
unsigned char buf[2] = {0xd0, TAKE8_16(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} else if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_16(d), 1); \
} else { \
if(d < (1<<8)) { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_16(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} \
} \
} while(0)
#define msgpack_pack_real_int32(x, d) \
do { \
if(d < -(1<<5)) { \
if(d < -(1<<15)) { \
/* signed 32 */ \
unsigned char buf[5]; \
buf[0] = 0xd2; _msgpack_store32(&buf[1], (int32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} else if(d < -(1<<7)) { \
/* signed 16 */ \
unsigned char buf[3]; \
buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* signed 8 */ \
unsigned char buf[2] = {0xd0, TAKE8_32(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} else if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_32(d), 1); \
} else { \
if(d < (1<<8)) { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_32(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else if(d < (1<<16)) { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* unsigned 32 */ \
unsigned char buf[5]; \
buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} \
} \
} while(0)
#define msgpack_pack_real_int64(x, d) \
do { \
if(d < -(1LL<<5)) { \
if(d < -(1LL<<15)) { \
if(d < -(1LL<<31)) { \
/* signed 64 */ \
unsigned char buf[9]; \
buf[0] = 0xd3; _msgpack_store64(&buf[1], d); \
msgpack_pack_append_buffer(x, buf, 9); \
} else { \
/* signed 32 */ \
unsigned char buf[5]; \
buf[0] = 0xd2; _msgpack_store32(&buf[1], (int32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} \
} else { \
if(d < -(1<<7)) { \
/* signed 16 */ \
unsigned char buf[3]; \
buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* signed 8 */ \
unsigned char buf[2] = {0xd0, TAKE8_64(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} \
} else if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_64(d), 1); \
} else { \
if(d < (1LL<<16)) { \
if(d < (1<<8)) { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_64(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} \
} else { \
if(d < (1LL<<32)) { \
/* unsigned 32 */ \
unsigned char buf[5]; \
buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} else { \
/* unsigned 64 */ \
unsigned char buf[9]; \
buf[0] = 0xcf; _msgpack_store64(&buf[1], d); \
msgpack_pack_append_buffer(x, buf, 9); \
} \
} \
} \
} while(0)
static inline int msgpack_pack_uint8(msgpack_packer* x, uint8_t d)
{
msgpack_pack_real_uint8(x, d);
}
static inline int msgpack_pack_uint16(msgpack_packer* x, uint16_t d)
{
msgpack_pack_real_uint16(x, d);
}
static inline int msgpack_pack_uint32(msgpack_packer* x, uint32_t d)
{
msgpack_pack_real_uint32(x, d);
}
static inline int msgpack_pack_uint64(msgpack_packer* x, uint64_t d)
{
msgpack_pack_real_uint64(x, d);
}
static inline int msgpack_pack_int8(msgpack_packer* x, int8_t d)
{
msgpack_pack_real_int8(x, d);
}
static inline int msgpack_pack_int16(msgpack_packer* x, int16_t d)
{
msgpack_pack_real_int16(x, d);
}
static inline int msgpack_pack_int32(msgpack_packer* x, int32_t d)
{
msgpack_pack_real_int32(x, d);
}
static inline int msgpack_pack_int64(msgpack_packer* x, int64_t d)
{
msgpack_pack_real_int64(x, d);
}
//#ifdef msgpack_pack_inline_func_cint
static inline int msgpack_pack_short(msgpack_packer* x, short d)
{
#if defined(SIZEOF_SHORT)
#if SIZEOF_SHORT == 2
msgpack_pack_real_int16(x, d);
#elif SIZEOF_SHORT == 4
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#elif defined(SHRT_MAX)
#if SHRT_MAX == 0x7fff
msgpack_pack_real_int16(x, d);
#elif SHRT_MAX == 0x7fffffff
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#else
if(sizeof(short) == 2) {
msgpack_pack_real_int16(x, d);
} else if(sizeof(short) == 4) {
msgpack_pack_real_int32(x, d);
} else {
msgpack_pack_real_int64(x, d);
}
#endif
}
static inline int msgpack_pack_int(msgpack_packer* x, int d)
{
#if defined(SIZEOF_INT)
#if SIZEOF_INT == 2
msgpack_pack_real_int16(x, d);
#elif SIZEOF_INT == 4
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#elif defined(INT_MAX)
#if INT_MAX == 0x7fff
msgpack_pack_real_int16(x, d);
#elif INT_MAX == 0x7fffffff
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#else
if(sizeof(int) == 2) {
msgpack_pack_real_int16(x, d);
} else if(sizeof(int) == 4) {
msgpack_pack_real_int32(x, d);
} else {
msgpack_pack_real_int64(x, d);
}
#endif
}
static inline int msgpack_pack_long(msgpack_packer* x, long d)
{
#if defined(SIZEOF_LONG)
#if SIZEOF_LONG == 2
msgpack_pack_real_int16(x, d);
#elif SIZEOF_LONG == 4
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#elif defined(LONG_MAX)
#if LONG_MAX == 0x7fffL
msgpack_pack_real_int16(x, d);
#elif LONG_MAX == 0x7fffffffL
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#else
if(sizeof(long) == 2) {
msgpack_pack_real_int16(x, d);
} else if(sizeof(long) == 4) {
msgpack_pack_real_int32(x, d);
} else {
msgpack_pack_real_int64(x, d);
}
#endif
}
static inline int msgpack_pack_long_long(msgpack_packer* x, long long d)
{
#if defined(SIZEOF_LONG_LONG)
#if SIZEOF_LONG_LONG == 2
msgpack_pack_real_int16(x, d);
#elif SIZEOF_LONG_LONG == 4
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#elif defined(LLONG_MAX)
#if LLONG_MAX == 0x7fffL
msgpack_pack_real_int16(x, d);
#elif LLONG_MAX == 0x7fffffffL
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#else
if(sizeof(long long) == 2) {
msgpack_pack_real_int16(x, d);
} else if(sizeof(long long) == 4) {
msgpack_pack_real_int32(x, d);
} else {
msgpack_pack_real_int64(x, d);
}
#endif
}
static inline int msgpack_pack_unsigned_short(msgpack_packer* x, unsigned short d)
{
#if defined(SIZEOF_SHORT)
#if SIZEOF_SHORT == 2
msgpack_pack_real_uint16(x, d);
#elif SIZEOF_SHORT == 4
msgpack_pack_real_uint32(x, d);
#else
msgpack_pack_real_uint64(x, d);
#endif
#elif defined(USHRT_MAX)
#if USHRT_MAX == 0xffffU
msgpack_pack_real_uint16(x, d);
#elif USHRT_MAX == 0xffffffffU
msgpack_pack_real_uint32(x, d);
#else
msgpack_pack_real_uint64(x, d);
#endif
#else
if(sizeof(unsigned short) == 2) {
msgpack_pack_real_uint16(x, d);
} else if(sizeof(unsigned short) == 4) {
msgpack_pack_real_uint32(x, d);
} else {
msgpack_pack_real_uint64(x, d);
}
#endif
}
static inline int msgpack_pack_unsigned_int(msgpack_packer* x, unsigned int d)
{
#if defined(SIZEOF_INT)
#if SIZEOF_INT == 2
msgpack_pack_real_uint16(x, d);
#elif SIZEOF_INT == 4
msgpack_pack_real_uint32(x, d);
#else
msgpack_pack_real_uint64(x, d);
#endif
#elif defined(UINT_MAX)
#if UINT_MAX == 0xffffU
msgpack_pack_real_uint16(x, d);
#elif UINT_MAX == 0xffffffffU
msgpack_pack_real_uint32(x, d);
#else
msgpack_pack_real_uint64(x, d);
#endif
#else
if(sizeof(unsigned int) == 2) {
msgpack_pack_real_uint16(x, d);
} else if(sizeof(unsigned int) == 4) {
msgpack_pack_real_uint32(x, d);
} else {
msgpack_pack_real_uint64(x, d);
}
#endif
}
static inline int msgpack_pack_unsigned_long(msgpack_packer* x, unsigned long d)
{
#if defined(SIZEOF_LONG)
#if SIZEOF_LONG == 2
msgpack_pack_real_uint16(x, d);
#elif SIZEOF_LONG == 4
msgpack_pack_real_uint32(x, d);
#else
msgpack_pack_real_uint64(x, d);
#endif
#elif defined(ULONG_MAX)
#if ULONG_MAX == 0xffffUL
msgpack_pack_real_uint16(x, d);
#elif ULONG_MAX == 0xffffffffUL
msgpack_pack_real_uint32(x, d);
#else
msgpack_pack_real_uint64(x, d);
#endif
#else
if(sizeof(unsigned long) == 2) {
msgpack_pack_real_uint16(x, d);
} else if(sizeof(unsigned long) == 4) {
msgpack_pack_real_uint32(x, d);
} else {
msgpack_pack_real_uint64(x, d);
}
#endif
}
static inline int msgpack_pack_unsigned_long_long(msgpack_packer* x, unsigned long long d)
{
#if defined(SIZEOF_LONG_LONG)
#if SIZEOF_LONG_LONG == 2
msgpack_pack_real_uint16(x, d);
#elif SIZEOF_LONG_LONG == 4
msgpack_pack_real_uint32(x, d);
#else
msgpack_pack_real_uint64(x, d);
#endif
#elif defined(ULLONG_MAX)
#if ULLONG_MAX == 0xffffUL
msgpack_pack_real_uint16(x, d);
#elif ULLONG_MAX == 0xffffffffUL
msgpack_pack_real_uint32(x, d);
#else
msgpack_pack_real_uint64(x, d);
#endif
#else
if(sizeof(unsigned long long) == 2) {
msgpack_pack_real_uint16(x, d);
} else if(sizeof(unsigned long long) == 4) {
msgpack_pack_real_uint32(x, d);
} else {
msgpack_pack_real_uint64(x, d);
}
#endif
}
//#undef msgpack_pack_inline_func_cint
//#endif
/*
* Float
*/
static inline int msgpack_pack_float(msgpack_packer* x, float d)
{
unsigned char buf[5];
buf[0] = 0xca;
_PyFloat_Pack4(d, &buf[1], 0);
msgpack_pack_append_buffer(x, buf, 5);
}
static inline int msgpack_pack_double(msgpack_packer* x, double d)
{
unsigned char buf[9];
buf[0] = 0xcb;
_PyFloat_Pack8(d, &buf[1], 0);
msgpack_pack_append_buffer(x, buf, 9);
}
/*
* Nil
*/
static inline int msgpack_pack_nil(msgpack_packer* x)
{
static const unsigned char d = 0xc0;
msgpack_pack_append_buffer(x, &d, 1);
}
/*
* Boolean
*/
static inline int msgpack_pack_true(msgpack_packer* x)
{
static const unsigned char d = 0xc3;
msgpack_pack_append_buffer(x, &d, 1);
}
static inline int msgpack_pack_false(msgpack_packer* x)
{
static const unsigned char d = 0xc2;
msgpack_pack_append_buffer(x, &d, 1);
}
/*
* Array
*/
static inline int msgpack_pack_array(msgpack_packer* x, unsigned int n)
{
if(n < 16) {
unsigned char d = 0x90 | n;
msgpack_pack_append_buffer(x, &d, 1);
} else if(n < 65536) {
unsigned char buf[3];
buf[0] = 0xdc; _msgpack_store16(&buf[1], (uint16_t)n);
msgpack_pack_append_buffer(x, buf, 3);
} else {
unsigned char buf[5];
buf[0] = 0xdd; _msgpack_store32(&buf[1], (uint32_t)n);
msgpack_pack_append_buffer(x, buf, 5);
}
}
/*
* Map
*/
static inline int msgpack_pack_map(msgpack_packer* x, unsigned int n)
{
if(n < 16) {
unsigned char d = 0x80 | n;
msgpack_pack_append_buffer(x, &TAKE8_8(d), 1);
} else if(n < 65536) {
unsigned char buf[3];
buf[0] = 0xde; _msgpack_store16(&buf[1], (uint16_t)n);
msgpack_pack_append_buffer(x, buf, 3);
} else {
unsigned char buf[5];
buf[0] = 0xdf; _msgpack_store32(&buf[1], (uint32_t)n);
msgpack_pack_append_buffer(x, buf, 5);
}
}
/*
* Raw
*/
static inline int msgpack_pack_raw(msgpack_packer* x, size_t l)
{
if (l < 32) {
unsigned char d = 0xa0 | (uint8_t)l;
msgpack_pack_append_buffer(x, &TAKE8_8(d), 1);
} else if (x->use_bin_type && l < 256) { // str8 is new format introduced with bin.
unsigned char buf[2] = {0xd9, (uint8_t)l};
msgpack_pack_append_buffer(x, buf, 2);
} else if (l < 65536) {
unsigned char buf[3];
buf[0] = 0xda; _msgpack_store16(&buf[1], (uint16_t)l);
msgpack_pack_append_buffer(x, buf, 3);
} else {
unsigned char buf[5];
buf[0] = 0xdb; _msgpack_store32(&buf[1], (uint32_t)l);
msgpack_pack_append_buffer(x, buf, 5);
}
}
/*
* bin
*/
static inline int msgpack_pack_bin(msgpack_packer *x, size_t l)
{
if (!x->use_bin_type) {
return msgpack_pack_raw(x, l);
}
if (l < 256) {
unsigned char buf[2] = {0xc4, (unsigned char)l};
msgpack_pack_append_buffer(x, buf, 2);
} else if (l < 65536) {
unsigned char buf[3] = {0xc5};
_msgpack_store16(&buf[1], (uint16_t)l);
msgpack_pack_append_buffer(x, buf, 3);
} else {
unsigned char buf[5] = {0xc6};
_msgpack_store32(&buf[1], (uint32_t)l);
msgpack_pack_append_buffer(x, buf, 5);
}
}
static inline int msgpack_pack_raw_body(msgpack_packer* x, const void* b, size_t l)
{
if (l > 0) msgpack_pack_append_buffer(x, (const unsigned char*)b, l);
return 0;
}
/*
* Ext
*/
static inline int msgpack_pack_ext(msgpack_packer* x, char typecode, size_t l)
{
if (l == 1) {
unsigned char buf[2];
buf[0] = 0xd4;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l == 2) {
unsigned char buf[2];
buf[0] = 0xd5;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l == 4) {
unsigned char buf[2];
buf[0] = 0xd6;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l == 8) {
unsigned char buf[2];
buf[0] = 0xd7;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l == 16) {
unsigned char buf[2];
buf[0] = 0xd8;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l < 256) {
unsigned char buf[3];
buf[0] = 0xc7;
buf[1] = l;
buf[2] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 3);
} else if(l < 65536) {
unsigned char buf[4];
buf[0] = 0xc8;
_msgpack_store16(&buf[1], (uint16_t)l);
buf[3] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 4);
} else {
unsigned char buf[6];
buf[0] = 0xc9;
_msgpack_store32(&buf[1], (uint32_t)l);
buf[5] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 6);
}
}
/*
* Pack Timestamp extension type. Follows msgpack-c pack_template.h.
*/
static inline int msgpack_pack_timestamp(msgpack_packer* x, int64_t seconds, uint32_t nanoseconds)
{
if ((seconds >> 34) == 0) {
/* seconds is unsigned and fits in 34 bits */
uint64_t data64 = ((uint64_t)nanoseconds << 34) | (uint64_t)seconds;
if ((data64 & 0xffffffff00000000L) == 0) {
/* no nanoseconds and seconds is 32bits or smaller. timestamp32. */
unsigned char buf[4];
uint32_t data32 = (uint32_t)data64;
msgpack_pack_ext(x, -1, 4);
_msgpack_store32(buf, data32);
msgpack_pack_raw_body(x, buf, 4);
} else {
/* timestamp64 */
unsigned char buf[8];
msgpack_pack_ext(x, -1, 8);
_msgpack_store64(buf, data64);
msgpack_pack_raw_body(x, buf, 8);
}
} else {
/* seconds is signed or >34bits */
unsigned char buf[12];
_msgpack_store32(&buf[0], nanoseconds);
_msgpack_store64(&buf[4], seconds);
msgpack_pack_ext(x, -1, 12);
msgpack_pack_raw_body(x, buf, 12);
}
return 0;
}
#undef msgpack_pack_append_buffer
#undef TAKE8_8
#undef TAKE8_16
#undef TAKE8_32
#undef TAKE8_64
#undef msgpack_pack_real_uint8
#undef msgpack_pack_real_uint16
#undef msgpack_pack_real_uint32
#undef msgpack_pack_real_uint64
#undef msgpack_pack_real_int8
#undef msgpack_pack_real_int16
#undef msgpack_pack_real_int32
#undef msgpack_pack_real_int64

@ -0,0 +1,194 @@
/*
* MessagePack system dependencies
*
* Copyright (C) 2008-2010 FURUHASHI Sadayuki
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MSGPACK_SYSDEP_H__
#define MSGPACK_SYSDEP_H__
#include <stdlib.h>
#include <stddef.h>
#if defined(_MSC_VER) && _MSC_VER < 1600
typedef __int8 int8_t;
typedef unsigned __int8 uint8_t;
typedef __int16 int16_t;
typedef unsigned __int16 uint16_t;
typedef __int32 int32_t;
typedef unsigned __int32 uint32_t;
typedef __int64 int64_t;
typedef unsigned __int64 uint64_t;
#elif defined(_MSC_VER) // && _MSC_VER >= 1600
#include <stdint.h>
#else
#include <stdint.h>
#include <stdbool.h>
#endif
#ifdef _WIN32
#define _msgpack_atomic_counter_header <windows.h>
typedef long _msgpack_atomic_counter_t;
#define _msgpack_sync_decr_and_fetch(ptr) InterlockedDecrement(ptr)
#define _msgpack_sync_incr_and_fetch(ptr) InterlockedIncrement(ptr)
#elif defined(__GNUC__) && ((__GNUC__*10 + __GNUC_MINOR__) < 41)
#define _msgpack_atomic_counter_header "gcc_atomic.h"
#else
typedef unsigned int _msgpack_atomic_counter_t;
#define _msgpack_sync_decr_and_fetch(ptr) __sync_sub_and_fetch(ptr, 1)
#define _msgpack_sync_incr_and_fetch(ptr) __sync_add_and_fetch(ptr, 1)
#endif
#ifdef _WIN32
#ifdef __cplusplus
/* numeric_limits<T>::min,max */
#ifdef max
#undef max
#endif
#ifdef min
#undef min
#endif
#endif
#else
#include <arpa/inet.h> /* __BYTE_ORDER */
#endif
#if !defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)
#if __BYTE_ORDER == __LITTLE_ENDIAN
#define __LITTLE_ENDIAN__
#elif __BYTE_ORDER == __BIG_ENDIAN
#define __BIG_ENDIAN__
#elif _WIN32
#define __LITTLE_ENDIAN__
#endif
#endif
#ifdef __LITTLE_ENDIAN__
#ifdef _WIN32
# if defined(ntohs)
# define _msgpack_be16(x) ntohs(x)
# elif defined(_byteswap_ushort) || (defined(_MSC_VER) && _MSC_VER >= 1400)
# define _msgpack_be16(x) ((uint16_t)_byteswap_ushort((unsigned short)x))
# else
# define _msgpack_be16(x) ( \
((((uint16_t)x) << 8) ) | \
((((uint16_t)x) >> 8) ) )
# endif
#else
# define _msgpack_be16(x) ntohs(x)
#endif
#ifdef _WIN32
# if defined(ntohl)
# define _msgpack_be32(x) ntohl(x)
# elif defined(_byteswap_ulong) || (defined(_MSC_VER) && _MSC_VER >= 1400)
# define _msgpack_be32(x) ((uint32_t)_byteswap_ulong((unsigned long)x))
# else
# define _msgpack_be32(x) \
( ((((uint32_t)x) << 24) ) | \
((((uint32_t)x) << 8) & 0x00ff0000U ) | \
((((uint32_t)x) >> 8) & 0x0000ff00U ) | \
((((uint32_t)x) >> 24) ) )
# endif
#else
# define _msgpack_be32(x) ntohl(x)
#endif
#if defined(_byteswap_uint64) || (defined(_MSC_VER) && _MSC_VER >= 1400)
# define _msgpack_be64(x) (_byteswap_uint64(x))
#elif defined(bswap_64)
# define _msgpack_be64(x) bswap_64(x)
#elif defined(__DARWIN_OSSwapInt64)
# define _msgpack_be64(x) __DARWIN_OSSwapInt64(x)
#else
#define _msgpack_be64(x) \
( ((((uint64_t)x) << 56) ) | \
((((uint64_t)x) << 40) & 0x00ff000000000000ULL ) | \
((((uint64_t)x) << 24) & 0x0000ff0000000000ULL ) | \
((((uint64_t)x) << 8) & 0x000000ff00000000ULL ) | \
((((uint64_t)x) >> 8) & 0x00000000ff000000ULL ) | \
((((uint64_t)x) >> 24) & 0x0000000000ff0000ULL ) | \
((((uint64_t)x) >> 40) & 0x000000000000ff00ULL ) | \
((((uint64_t)x) >> 56) ) )
#endif
#define _msgpack_load16(cast, from) ((cast)( \
(((uint16_t)((uint8_t*)(from))[0]) << 8) | \
(((uint16_t)((uint8_t*)(from))[1]) ) ))
#define _msgpack_load32(cast, from) ((cast)( \
(((uint32_t)((uint8_t*)(from))[0]) << 24) | \
(((uint32_t)((uint8_t*)(from))[1]) << 16) | \
(((uint32_t)((uint8_t*)(from))[2]) << 8) | \
(((uint32_t)((uint8_t*)(from))[3]) ) ))
#define _msgpack_load64(cast, from) ((cast)( \
(((uint64_t)((uint8_t*)(from))[0]) << 56) | \
(((uint64_t)((uint8_t*)(from))[1]) << 48) | \
(((uint64_t)((uint8_t*)(from))[2]) << 40) | \
(((uint64_t)((uint8_t*)(from))[3]) << 32) | \
(((uint64_t)((uint8_t*)(from))[4]) << 24) | \
(((uint64_t)((uint8_t*)(from))[5]) << 16) | \
(((uint64_t)((uint8_t*)(from))[6]) << 8) | \
(((uint64_t)((uint8_t*)(from))[7]) ) ))
#else
#define _msgpack_be16(x) (x)
#define _msgpack_be32(x) (x)
#define _msgpack_be64(x) (x)
#define _msgpack_load16(cast, from) ((cast)( \
(((uint16_t)((uint8_t*)from)[0]) << 8) | \
(((uint16_t)((uint8_t*)from)[1]) ) ))
#define _msgpack_load32(cast, from) ((cast)( \
(((uint32_t)((uint8_t*)from)[0]) << 24) | \
(((uint32_t)((uint8_t*)from)[1]) << 16) | \
(((uint32_t)((uint8_t*)from)[2]) << 8) | \
(((uint32_t)((uint8_t*)from)[3]) ) ))
#define _msgpack_load64(cast, from) ((cast)( \
(((uint64_t)((uint8_t*)from)[0]) << 56) | \
(((uint64_t)((uint8_t*)from)[1]) << 48) | \
(((uint64_t)((uint8_t*)from)[2]) << 40) | \
(((uint64_t)((uint8_t*)from)[3]) << 32) | \
(((uint64_t)((uint8_t*)from)[4]) << 24) | \
(((uint64_t)((uint8_t*)from)[5]) << 16) | \
(((uint64_t)((uint8_t*)from)[6]) << 8) | \
(((uint64_t)((uint8_t*)from)[7]) ) ))
#endif
#define _msgpack_store16(to, num) \
do { uint16_t val = _msgpack_be16(num); memcpy(to, &val, 2); } while(0)
#define _msgpack_store32(to, num) \
do { uint32_t val = _msgpack_be32(num); memcpy(to, &val, 4); } while(0)
#define _msgpack_store64(to, num) \
do { uint64_t val = _msgpack_be64(num); memcpy(to, &val, 8); } while(0)
/*
#define _msgpack_load16(cast, from) \
({ cast val; memcpy(&val, (char*)from, 2); _msgpack_be16(val); })
#define _msgpack_load32(cast, from) \
({ cast val; memcpy(&val, (char*)from, 4); _msgpack_be32(val); })
#define _msgpack_load64(cast, from) \
({ cast val; memcpy(&val, (char*)from, 8); _msgpack_be64(val); })
*/
#endif /* msgpack/sysdep.h */

@ -0,0 +1,391 @@
/*
* MessagePack for Python unpacking routine
*
* Copyright (C) 2009 Naoki INADA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define MSGPACK_EMBED_STACK_SIZE (1024)
#include "unpack_define.h"
typedef struct unpack_user {
bool use_list;
bool raw;
bool has_pairs_hook;
bool strict_map_key;
int timestamp;
PyObject *object_hook;
PyObject *list_hook;
PyObject *ext_hook;
PyObject *timestamp_t;
PyObject *giga;
PyObject *utc;
const char *unicode_errors;
Py_ssize_t max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len;
} unpack_user;
typedef PyObject* msgpack_unpack_object;
struct unpack_context;
typedef struct unpack_context unpack_context;
typedef int (*execute_fn)(unpack_context *ctx, const char* data, Py_ssize_t len, Py_ssize_t* off);
static inline msgpack_unpack_object unpack_callback_root(unpack_user* u)
{
return NULL;
}
static inline int unpack_callback_uint16(unpack_user* u, uint16_t d, msgpack_unpack_object* o)
{
PyObject *p = PyInt_FromLong((long)d);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_uint8(unpack_user* u, uint8_t d, msgpack_unpack_object* o)
{
return unpack_callback_uint16(u, d, o);
}
static inline int unpack_callback_uint32(unpack_user* u, uint32_t d, msgpack_unpack_object* o)
{
PyObject *p = PyInt_FromSize_t((size_t)d);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_uint64(unpack_user* u, uint64_t d, msgpack_unpack_object* o)
{
PyObject *p;
if (d > LONG_MAX) {
p = PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG)d);
} else {
p = PyInt_FromLong((long)d);
}
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_int32(unpack_user* u, int32_t d, msgpack_unpack_object* o)
{
PyObject *p = PyInt_FromLong(d);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_int16(unpack_user* u, int16_t d, msgpack_unpack_object* o)
{
return unpack_callback_int32(u, d, o);
}
static inline int unpack_callback_int8(unpack_user* u, int8_t d, msgpack_unpack_object* o)
{
return unpack_callback_int32(u, d, o);
}
static inline int unpack_callback_int64(unpack_user* u, int64_t d, msgpack_unpack_object* o)
{
PyObject *p;
if (d > LONG_MAX || d < LONG_MIN) {
p = PyLong_FromLongLong((PY_LONG_LONG)d);
} else {
p = PyInt_FromLong((long)d);
}
*o = p;
return 0;
}
static inline int unpack_callback_double(unpack_user* u, double d, msgpack_unpack_object* o)
{
PyObject *p = PyFloat_FromDouble(d);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_float(unpack_user* u, float d, msgpack_unpack_object* o)
{
return unpack_callback_double(u, d, o);
}
static inline int unpack_callback_nil(unpack_user* u, msgpack_unpack_object* o)
{ Py_INCREF(Py_None); *o = Py_None; return 0; }
static inline int unpack_callback_true(unpack_user* u, msgpack_unpack_object* o)
{ Py_INCREF(Py_True); *o = Py_True; return 0; }
static inline int unpack_callback_false(unpack_user* u, msgpack_unpack_object* o)
{ Py_INCREF(Py_False); *o = Py_False; return 0; }
static inline int unpack_callback_array(unpack_user* u, unsigned int n, msgpack_unpack_object* o)
{
if (n > u->max_array_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_array_len(%zd)", n, u->max_array_len);
return -1;
}
PyObject *p = u->use_list ? PyList_New(n) : PyTuple_New(n);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_array_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object o)
{
if (u->use_list)
PyList_SET_ITEM(*c, current, o);
else
PyTuple_SET_ITEM(*c, current, o);
return 0;
}
static inline int unpack_callback_array_end(unpack_user* u, msgpack_unpack_object* c)
{
if (u->list_hook) {
PyObject *new_c = PyObject_CallFunctionObjArgs(u->list_hook, *c, NULL);
if (!new_c)
return -1;
Py_DECREF(*c);
*c = new_c;
}
return 0;
}
static inline int unpack_callback_map(unpack_user* u, unsigned int n, msgpack_unpack_object* o)
{
if (n > u->max_map_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_map_len(%zd)", n, u->max_map_len);
return -1;
}
PyObject *p;
if (u->has_pairs_hook) {
p = PyList_New(n); // Or use tuple?
}
else {
p = PyDict_New();
}
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_map_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object k, msgpack_unpack_object v)
{
if (u->strict_map_key && !PyUnicode_CheckExact(k) && !PyBytes_CheckExact(k)) {
PyErr_Format(PyExc_ValueError, "%.100s is not allowed for map key", Py_TYPE(k)->tp_name);
return -1;
}
if (PyUnicode_CheckExact(k)) {
PyUnicode_InternInPlace(&k);
}
if (u->has_pairs_hook) {
msgpack_unpack_object item = PyTuple_Pack(2, k, v);
if (!item)
return -1;
Py_DECREF(k);
Py_DECREF(v);
PyList_SET_ITEM(*c, current, item);
return 0;
}
else if (PyDict_SetItem(*c, k, v) == 0) {
Py_DECREF(k);
Py_DECREF(v);
return 0;
}
return -1;
}
static inline int unpack_callback_map_end(unpack_user* u, msgpack_unpack_object* c)
{
if (u->object_hook) {
PyObject *new_c = PyObject_CallFunctionObjArgs(u->object_hook, *c, NULL);
if (!new_c)
return -1;
Py_DECREF(*c);
*c = new_c;
}
return 0;
}
static inline int unpack_callback_raw(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o)
{
if (l > u->max_str_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_str_len(%zd)", l, u->max_str_len);
return -1;
}
PyObject *py;
if (u->raw) {
py = PyBytes_FromStringAndSize(p, l);
} else {
py = PyUnicode_DecodeUTF8(p, l, u->unicode_errors);
}
if (!py)
return -1;
*o = py;
return 0;
}
static inline int unpack_callback_bin(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o)
{
if (l > u->max_bin_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_bin_len(%zd)", l, u->max_bin_len);
return -1;
}
PyObject *py = PyBytes_FromStringAndSize(p, l);
if (!py)
return -1;
*o = py;
return 0;
}
typedef struct msgpack_timestamp {
int64_t tv_sec;
uint32_t tv_nsec;
} msgpack_timestamp;
/*
* Unpack ext buffer to a timestamp. Pulled from msgpack-c timestamp.h.
*/
static int unpack_timestamp(const char* buf, unsigned int buflen, msgpack_timestamp* ts) {
switch (buflen) {
case 4:
ts->tv_nsec = 0;
{
uint32_t v = _msgpack_load32(uint32_t, buf);
ts->tv_sec = (int64_t)v;
}
return 0;
case 8: {
uint64_t value =_msgpack_load64(uint64_t, buf);
ts->tv_nsec = (uint32_t)(value >> 34);
ts->tv_sec = value & 0x00000003ffffffffLL;
return 0;
}
case 12:
ts->tv_nsec = _msgpack_load32(uint32_t, buf);
ts->tv_sec = _msgpack_load64(int64_t, buf + 4);
return 0;
default:
return -1;
}
}
#include "datetime.h"
static int unpack_callback_ext(unpack_user* u, const char* base, const char* pos,
unsigned int length, msgpack_unpack_object* o)
{
int8_t typecode = (int8_t)*pos++;
if (!u->ext_hook) {
PyErr_SetString(PyExc_AssertionError, "u->ext_hook cannot be NULL");
return -1;
}
if (length-1 > u->max_ext_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_ext_len(%zd)", length, u->max_ext_len);
return -1;
}
PyObject *py = NULL;
// length also includes the typecode, so the actual data is length-1
if (typecode == -1) {
msgpack_timestamp ts;
if (unpack_timestamp(pos, length-1, &ts) < 0) {
return -1;
}
if (u->timestamp == 2) { // int
PyObject *a = PyLong_FromLongLong(ts.tv_sec);
if (a == NULL) return -1;
PyObject *c = PyNumber_Multiply(a, u->giga);
Py_DECREF(a);
if (c == NULL) {
return -1;
}
PyObject *b = PyLong_FromUnsignedLong(ts.tv_nsec);
if (b == NULL) {
Py_DECREF(c);
return -1;
}
py = PyNumber_Add(c, b);
Py_DECREF(c);
Py_DECREF(b);
}
else if (u->timestamp == 0) { // Timestamp
py = PyObject_CallFunction(u->timestamp_t, "(Lk)", ts.tv_sec, ts.tv_nsec);
}
else if (u->timestamp == 3) { // datetime
// Calculate datetime using epoch + delta
// due to limitations PyDateTime_FromTimestamp on Windows with negative timestamps
PyObject *epoch = PyDateTimeAPI->DateTime_FromDateAndTime(1970, 1, 1, 0, 0, 0, 0, u->utc, PyDateTimeAPI->DateTimeType);
if (epoch == NULL) {
return -1;
}
PyObject* d = PyDelta_FromDSU(ts.tv_sec/(24*3600), ts.tv_sec%(24*3600), ts.tv_nsec / 1000);
if (d == NULL) {
Py_DECREF(epoch);
return -1;
}
py = PyNumber_Add(epoch, d);
Py_DECREF(epoch);
Py_DECREF(d);
}
else { // float
PyObject *a = PyFloat_FromDouble((double)ts.tv_nsec);
if (a == NULL) return -1;
PyObject *b = PyNumber_TrueDivide(a, u->giga);
Py_DECREF(a);
if (b == NULL) return -1;
PyObject *c = PyLong_FromLongLong(ts.tv_sec);
if (c == NULL) {
Py_DECREF(b);
return -1;
}
a = PyNumber_Add(b, c);
Py_DECREF(b);
Py_DECREF(c);
py = a;
}
} else {
py = PyObject_CallFunction(u->ext_hook, "(iy#)", (int)typecode, pos, (Py_ssize_t)length-1);
}
if (!py)
return -1;
*o = py;
return 0;
}
#include "unpack_template.h"

@ -0,0 +1,95 @@
/*
* MessagePack unpacking routine template
*
* Copyright (C) 2008-2010 FURUHASHI Sadayuki
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MSGPACK_UNPACK_DEFINE_H__
#define MSGPACK_UNPACK_DEFINE_H__
#include "msgpack/sysdep.h"
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <stdio.h>
#ifdef __cplusplus
extern "C" {
#endif
#ifndef MSGPACK_EMBED_STACK_SIZE
#define MSGPACK_EMBED_STACK_SIZE 32
#endif
// CS is first byte & 0x1f
typedef enum {
CS_HEADER = 0x00, // nil
//CS_ = 0x01,
//CS_ = 0x02, // false
//CS_ = 0x03, // true
CS_BIN_8 = 0x04,
CS_BIN_16 = 0x05,
CS_BIN_32 = 0x06,
CS_EXT_8 = 0x07,
CS_EXT_16 = 0x08,
CS_EXT_32 = 0x09,
CS_FLOAT = 0x0a,
CS_DOUBLE = 0x0b,
CS_UINT_8 = 0x0c,
CS_UINT_16 = 0x0d,
CS_UINT_32 = 0x0e,
CS_UINT_64 = 0x0f,
CS_INT_8 = 0x10,
CS_INT_16 = 0x11,
CS_INT_32 = 0x12,
CS_INT_64 = 0x13,
//CS_FIXEXT1 = 0x14,
//CS_FIXEXT2 = 0x15,
//CS_FIXEXT4 = 0x16,
//CS_FIXEXT8 = 0x17,
//CS_FIXEXT16 = 0x18,
CS_RAW_8 = 0x19,
CS_RAW_16 = 0x1a,
CS_RAW_32 = 0x1b,
CS_ARRAY_16 = 0x1c,
CS_ARRAY_32 = 0x1d,
CS_MAP_16 = 0x1e,
CS_MAP_32 = 0x1f,
ACS_RAW_VALUE,
ACS_BIN_VALUE,
ACS_EXT_VALUE,
} msgpack_unpack_state;
typedef enum {
CT_ARRAY_ITEM,
CT_MAP_KEY,
CT_MAP_VALUE,
} msgpack_container_type;
#ifdef __cplusplus
}
#endif
#endif /* msgpack/unpack_define.h */

@ -0,0 +1,454 @@
/*
* MessagePack unpacking routine template
*
* Copyright (C) 2008-2010 FURUHASHI Sadayuki
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef USE_CASE_RANGE
#if !defined(_MSC_VER)
#define USE_CASE_RANGE
#endif
#endif
typedef struct unpack_stack {
PyObject* obj;
Py_ssize_t size;
Py_ssize_t count;
unsigned int ct;
PyObject* map_key;
} unpack_stack;
struct unpack_context {
unpack_user user;
unsigned int cs;
unsigned int trail;
unsigned int top;
/*
unpack_stack* stack;
unsigned int stack_size;
unpack_stack embed_stack[MSGPACK_EMBED_STACK_SIZE];
*/
unpack_stack stack[MSGPACK_EMBED_STACK_SIZE];
};
static inline void unpack_init(unpack_context* ctx)
{
ctx->cs = CS_HEADER;
ctx->trail = 0;
ctx->top = 0;
/*
ctx->stack = ctx->embed_stack;
ctx->stack_size = MSGPACK_EMBED_STACK_SIZE;
*/
ctx->stack[0].obj = unpack_callback_root(&ctx->user);
}
/*
static inline void unpack_destroy(unpack_context* ctx)
{
if(ctx->stack_size != MSGPACK_EMBED_STACK_SIZE) {
free(ctx->stack);
}
}
*/
static inline PyObject* unpack_data(unpack_context* ctx)
{
return (ctx)->stack[0].obj;
}
static inline void unpack_clear(unpack_context *ctx)
{
Py_CLEAR(ctx->stack[0].obj);
}
template <bool construct>
static inline int unpack_execute(unpack_context* ctx, const char* data, Py_ssize_t len, Py_ssize_t* off)
{
assert(len >= *off);
const unsigned char* p = (unsigned char*)data + *off;
const unsigned char* const pe = (unsigned char*)data + len;
const void* n = p;
unsigned int trail = ctx->trail;
unsigned int cs = ctx->cs;
unsigned int top = ctx->top;
unpack_stack* stack = ctx->stack;
/*
unsigned int stack_size = ctx->stack_size;
*/
unpack_user* user = &ctx->user;
PyObject* obj = NULL;
unpack_stack* c = NULL;
int ret;
#define construct_cb(name) \
construct && unpack_callback ## name
#define push_simple_value(func) \
if(construct_cb(func)(user, &obj) < 0) { goto _failed; } \
goto _push
#define push_fixed_value(func, arg) \
if(construct_cb(func)(user, arg, &obj) < 0) { goto _failed; } \
goto _push
#define push_variable_value(func, base, pos, len) \
if(construct_cb(func)(user, \
(const char*)base, (const char*)pos, len, &obj) < 0) { goto _failed; } \
goto _push
#define again_fixed_trail(_cs, trail_len) \
trail = trail_len; \
cs = _cs; \
goto _fixed_trail_again
#define again_fixed_trail_if_zero(_cs, trail_len, ifzero) \
trail = trail_len; \
if(trail == 0) { goto ifzero; } \
cs = _cs; \
goto _fixed_trail_again
#define start_container(func, count_, ct_) \
if(top >= MSGPACK_EMBED_STACK_SIZE) { ret = -3; goto _end; } \
if(construct_cb(func)(user, count_, &stack[top].obj) < 0) { goto _failed; } \
if((count_) == 0) { obj = stack[top].obj; \
if (construct_cb(func##_end)(user, &obj) < 0) { goto _failed; } \
goto _push; } \
stack[top].ct = ct_; \
stack[top].size = count_; \
stack[top].count = 0; \
++top; \
goto _header_again
#define NEXT_CS(p) ((unsigned int)*p & 0x1f)
#ifdef USE_CASE_RANGE
#define SWITCH_RANGE_BEGIN switch(*p) {
#define SWITCH_RANGE(FROM, TO) case FROM ... TO:
#define SWITCH_RANGE_DEFAULT default:
#define SWITCH_RANGE_END }
#else
#define SWITCH_RANGE_BEGIN { if(0) {
#define SWITCH_RANGE(FROM, TO) } else if(FROM <= *p && *p <= TO) {
#define SWITCH_RANGE_DEFAULT } else {
#define SWITCH_RANGE_END } }
#endif
if(p == pe) { goto _out; }
do {
switch(cs) {
case CS_HEADER:
SWITCH_RANGE_BEGIN
SWITCH_RANGE(0x00, 0x7f) // Positive Fixnum
push_fixed_value(_uint8, *(uint8_t*)p);
SWITCH_RANGE(0xe0, 0xff) // Negative Fixnum
push_fixed_value(_int8, *(int8_t*)p);
SWITCH_RANGE(0xc0, 0xdf) // Variable
switch(*p) {
case 0xc0: // nil
push_simple_value(_nil);
//case 0xc1: // never used
case 0xc2: // false
push_simple_value(_false);
case 0xc3: // true
push_simple_value(_true);
case 0xc4: // bin 8
again_fixed_trail(NEXT_CS(p), 1);
case 0xc5: // bin 16
again_fixed_trail(NEXT_CS(p), 2);
case 0xc6: // bin 32
again_fixed_trail(NEXT_CS(p), 4);
case 0xc7: // ext 8
again_fixed_trail(NEXT_CS(p), 1);
case 0xc8: // ext 16
again_fixed_trail(NEXT_CS(p), 2);
case 0xc9: // ext 32
again_fixed_trail(NEXT_CS(p), 4);
case 0xca: // float
case 0xcb: // double
case 0xcc: // unsigned int 8
case 0xcd: // unsigned int 16
case 0xce: // unsigned int 32
case 0xcf: // unsigned int 64
case 0xd0: // signed int 8
case 0xd1: // signed int 16
case 0xd2: // signed int 32
case 0xd3: // signed int 64
again_fixed_trail(NEXT_CS(p), 1 << (((unsigned int)*p) & 0x03));
case 0xd4: // fixext 1
case 0xd5: // fixext 2
case 0xd6: // fixext 4
case 0xd7: // fixext 8
again_fixed_trail_if_zero(ACS_EXT_VALUE,
(1 << (((unsigned int)*p) & 0x03))+1,
_ext_zero);
case 0xd8: // fixext 16
again_fixed_trail_if_zero(ACS_EXT_VALUE, 16+1, _ext_zero);
case 0xd9: // str 8
again_fixed_trail(NEXT_CS(p), 1);
case 0xda: // raw 16
case 0xdb: // raw 32
case 0xdc: // array 16
case 0xdd: // array 32
case 0xde: // map 16
case 0xdf: // map 32
again_fixed_trail(NEXT_CS(p), 2 << (((unsigned int)*p) & 0x01));
default:
ret = -2;
goto _end;
}
SWITCH_RANGE(0xa0, 0xbf) // FixRaw
again_fixed_trail_if_zero(ACS_RAW_VALUE, ((unsigned int)*p & 0x1f), _raw_zero);
SWITCH_RANGE(0x90, 0x9f) // FixArray
start_container(_array, ((unsigned int)*p) & 0x0f, CT_ARRAY_ITEM);
SWITCH_RANGE(0x80, 0x8f) // FixMap
start_container(_map, ((unsigned int)*p) & 0x0f, CT_MAP_KEY);
SWITCH_RANGE_DEFAULT
ret = -2;
goto _end;
SWITCH_RANGE_END
// end CS_HEADER
_fixed_trail_again:
++p;
default:
if((size_t)(pe - p) < trail) { goto _out; }
n = p; p += trail - 1;
switch(cs) {
case CS_EXT_8:
again_fixed_trail_if_zero(ACS_EXT_VALUE, *(uint8_t*)n+1, _ext_zero);
case CS_EXT_16:
again_fixed_trail_if_zero(ACS_EXT_VALUE,
_msgpack_load16(uint16_t,n)+1,
_ext_zero);
case CS_EXT_32:
again_fixed_trail_if_zero(ACS_EXT_VALUE,
_msgpack_load32(uint32_t,n)+1,
_ext_zero);
case CS_FLOAT: {
double f = _PyFloat_Unpack4((unsigned char*)n, 0);
push_fixed_value(_float, f); }
case CS_DOUBLE: {
double f = _PyFloat_Unpack8((unsigned char*)n, 0);
push_fixed_value(_double, f); }
case CS_UINT_8:
push_fixed_value(_uint8, *(uint8_t*)n);
case CS_UINT_16:
push_fixed_value(_uint16, _msgpack_load16(uint16_t,n));
case CS_UINT_32:
push_fixed_value(_uint32, _msgpack_load32(uint32_t,n));
case CS_UINT_64:
push_fixed_value(_uint64, _msgpack_load64(uint64_t,n));
case CS_INT_8:
push_fixed_value(_int8, *(int8_t*)n);
case CS_INT_16:
push_fixed_value(_int16, _msgpack_load16(int16_t,n));
case CS_INT_32:
push_fixed_value(_int32, _msgpack_load32(int32_t,n));
case CS_INT_64:
push_fixed_value(_int64, _msgpack_load64(int64_t,n));
case CS_BIN_8:
again_fixed_trail_if_zero(ACS_BIN_VALUE, *(uint8_t*)n, _bin_zero);
case CS_BIN_16:
again_fixed_trail_if_zero(ACS_BIN_VALUE, _msgpack_load16(uint16_t,n), _bin_zero);
case CS_BIN_32:
again_fixed_trail_if_zero(ACS_BIN_VALUE, _msgpack_load32(uint32_t,n), _bin_zero);
case ACS_BIN_VALUE:
_bin_zero:
push_variable_value(_bin, data, n, trail);
case CS_RAW_8:
again_fixed_trail_if_zero(ACS_RAW_VALUE, *(uint8_t*)n, _raw_zero);
case CS_RAW_16:
again_fixed_trail_if_zero(ACS_RAW_VALUE, _msgpack_load16(uint16_t,n), _raw_zero);
case CS_RAW_32:
again_fixed_trail_if_zero(ACS_RAW_VALUE, _msgpack_load32(uint32_t,n), _raw_zero);
case ACS_RAW_VALUE:
_raw_zero:
push_variable_value(_raw, data, n, trail);
case ACS_EXT_VALUE:
_ext_zero:
push_variable_value(_ext, data, n, trail);
case CS_ARRAY_16:
start_container(_array, _msgpack_load16(uint16_t,n), CT_ARRAY_ITEM);
case CS_ARRAY_32:
/* FIXME security guard */
start_container(_array, _msgpack_load32(uint32_t,n), CT_ARRAY_ITEM);
case CS_MAP_16:
start_container(_map, _msgpack_load16(uint16_t,n), CT_MAP_KEY);
case CS_MAP_32:
/* FIXME security guard */
start_container(_map, _msgpack_load32(uint32_t,n), CT_MAP_KEY);
default:
goto _failed;
}
}
_push:
if(top == 0) { goto _finish; }
c = &stack[top-1];
switch(c->ct) {
case CT_ARRAY_ITEM:
if(construct_cb(_array_item)(user, c->count, &c->obj, obj) < 0) { goto _failed; }
if(++c->count == c->size) {
obj = c->obj;
if (construct_cb(_array_end)(user, &obj) < 0) { goto _failed; }
--top;
/*printf("stack pop %d\n", top);*/
goto _push;
}
goto _header_again;
case CT_MAP_KEY:
c->map_key = obj;
c->ct = CT_MAP_VALUE;
goto _header_again;
case CT_MAP_VALUE:
if(construct_cb(_map_item)(user, c->count, &c->obj, c->map_key, obj) < 0) { goto _failed; }
if(++c->count == c->size) {
obj = c->obj;
if (construct_cb(_map_end)(user, &obj) < 0) { goto _failed; }
--top;
/*printf("stack pop %d\n", top);*/
goto _push;
}
c->ct = CT_MAP_KEY;
goto _header_again;
default:
goto _failed;
}
_header_again:
cs = CS_HEADER;
++p;
} while(p != pe);
goto _out;
_finish:
if (!construct)
unpack_callback_nil(user, &obj);
stack[0].obj = obj;
++p;
ret = 1;
/*printf("-- finish --\n"); */
goto _end;
_failed:
/*printf("** FAILED **\n"); */
ret = -1;
goto _end;
_out:
ret = 0;
goto _end;
_end:
ctx->cs = cs;
ctx->trail = trail;
ctx->top = top;
*off = p - (const unsigned char*)data;
return ret;
#undef construct_cb
}
#undef SWITCH_RANGE_BEGIN
#undef SWITCH_RANGE
#undef SWITCH_RANGE_DEFAULT
#undef SWITCH_RANGE_END
#undef push_simple_value
#undef push_fixed_value
#undef push_variable_value
#undef again_fixed_trail
#undef again_fixed_trail_if_zero
#undef start_container
template <unsigned int fixed_offset, unsigned int var_offset>
static inline int unpack_container_header(unpack_context* ctx, const char* data, Py_ssize_t len, Py_ssize_t* off)
{
assert(len >= *off);
uint32_t size;
const unsigned char *const p = (unsigned char*)data + *off;
#define inc_offset(inc) \
if (len - *off < inc) \
return 0; \
*off += inc;
switch (*p) {
case var_offset:
inc_offset(3);
size = _msgpack_load16(uint16_t, p + 1);
break;
case var_offset + 1:
inc_offset(5);
size = _msgpack_load32(uint32_t, p + 1);
break;
#ifdef USE_CASE_RANGE
case fixed_offset + 0x0 ... fixed_offset + 0xf:
#else
case fixed_offset + 0x0:
case fixed_offset + 0x1:
case fixed_offset + 0x2:
case fixed_offset + 0x3:
case fixed_offset + 0x4:
case fixed_offset + 0x5:
case fixed_offset + 0x6:
case fixed_offset + 0x7:
case fixed_offset + 0x8:
case fixed_offset + 0x9:
case fixed_offset + 0xa:
case fixed_offset + 0xb:
case fixed_offset + 0xc:
case fixed_offset + 0xd:
case fixed_offset + 0xe:
case fixed_offset + 0xf:
#endif
++*off;
size = ((unsigned int)*p) & 0x0f;
break;
default:
PyErr_SetString(PyExc_ValueError, "Unexpected type header on stream");
return -1;
}
unpack_callback_uint32(&ctx->user, size, &ctx->stack[0].obj);
return 1;
}
#undef SWITCH_RANGE_BEGIN
#undef SWITCH_RANGE
#undef SWITCH_RANGE_DEFAULT
#undef SWITCH_RANGE_END
static const execute_fn unpack_construct = &unpack_execute<true>;
static const execute_fn unpack_skip = &unpack_execute<false>;
static const execute_fn read_array_header = &unpack_container_header<0x90, 0xdc>;
static const execute_fn read_map_header = &unpack_container_header<0x80, 0xde>;
#undef NEXT_CS
/* vim: set ts=4 sw=4 sts=4 expandtab */

@ -9,14 +9,14 @@
Requests HTTP Library
~~~~~~~~~~~~~~~~~~~~~
Requests is an HTTP library, written in Python, for human beings. Basic GET
usage:
Requests is an HTTP library, written in Python, for human beings.
Basic GET usage:
>>> import requests
>>> r = requests.get('https://www.python.org')
>>> r.status_code
200
>>> 'Python is a programming language' in r.content
>>> b'Python is a programming language' in r.content
True
... or POST:
@ -27,14 +27,14 @@ usage:
{
...
"form": {
"key2": "value2",
"key1": "value1"
"key1": "value1",
"key2": "value2"
},
...
}
The other HTTP methods are supported - see `requests.api`. Full documentation
is at <http://python-requests.org>.
is at <https://requests.readthedocs.io>.
:copyright: (c) 2017 by Kenneth Reitz.
:license: Apache 2.0, see LICENSE for more details.
@ -57,18 +57,16 @@ def check_compatibility(urllib3_version, chardet_version):
# Check urllib3 for compatibility.
major, minor, patch = urllib3_version # noqa: F811
major, minor, patch = int(major), int(minor), int(patch)
# urllib3 >= 1.21.1, <= 1.25
# urllib3 >= 1.21.1, <= 1.26
assert major == 1
assert minor >= 21
assert minor <= 25
assert minor <= 26
# Check chardet for compatibility.
major, minor, patch = chardet_version.split('.')[:3]
major, minor, patch = int(major), int(minor), int(patch)
# chardet >= 3.0.2, < 3.1.0
assert major == 3
assert minor < 1
assert patch >= 2
# chardet >= 3.0.2, < 5.0.0
assert (3, 0, 2) <= (major, minor, patch) < (5, 0, 0)
def _check_cryptography(cryptography_version):
@ -90,14 +88,22 @@ except (AssertionError, ValueError):
"version!".format(urllib3.__version__, chardet.__version__),
RequestsDependencyWarning)
# Attempt to enable urllib3's SNI support, if possible
# Attempt to enable urllib3's fallback for SNI support
# if the standard library doesn't support SNI or the
# 'ssl' library isn't available.
try:
from urllib3.contrib import pyopenssl
pyopenssl.inject_into_urllib3()
try:
import ssl
except ImportError:
ssl = None
if not getattr(ssl, "HAS_SNI", False):
from urllib3.contrib import pyopenssl
pyopenssl.inject_into_urllib3()
# Check cryptography version
from cryptography import __version__ as cryptography_version
_check_cryptography(cryptography_version)
# Check cryptography version
from cryptography import __version__ as cryptography_version
_check_cryptography(cryptography_version)
except ImportError:
pass

@ -4,11 +4,11 @@
__title__ = 'requests'
__description__ = 'Python HTTP for Humans.'
__url__ = 'http://python-requests.org'
__version__ = '2.22.0'
__build__ = 0x022200
__url__ = 'https://requests.readthedocs.io'
__version__ = '2.25.1'
__build__ = 0x022501
__author__ = 'Kenneth Reitz'
__author_email__ = 'me@kennethreitz.org'
__license__ = 'Apache 2.0'
__copyright__ = 'Copyright 2019 Kenneth Reitz'
__copyright__ = 'Copyright 2020 Kenneth Reitz'
__cake__ = u'\u2728 \U0001f370 \u2728'

@ -16,7 +16,7 @@ from . import sessions
def request(method, url, **kwargs):
"""Constructs and sends a :class:`Request <Request>`.
:param method: method for the new :class:`Request` object.
:param method: method for the new :class:`Request` object: ``GET``, ``OPTIONS``, ``HEAD``, ``POST``, ``PUT``, ``PATCH``, or ``DELETE``.
:param url: URL for the new :class:`Request` object.
:param params: (optional) Dictionary, list of tuples or bytes to send
in the query string for the :class:`Request`.
@ -50,6 +50,7 @@ def request(method, url, **kwargs):
>>> import requests
>>> req = requests.request('GET', 'https://httpbin.org/get')
>>> req
<Response [200]>
"""
@ -92,7 +93,9 @@ def head(url, **kwargs):
r"""Sends a HEAD request.
:param url: URL for the new :class:`Request` object.
:param \*\*kwargs: Optional arguments that ``request`` takes.
:param \*\*kwargs: Optional arguments that ``request`` takes. If
`allow_redirects` is not provided, it will be set to `False` (as
opposed to the default :meth:`request` behavior).
:return: :class:`Response <Response>` object
:rtype: requests.Response
"""

@ -50,7 +50,7 @@ def _basic_auth_str(username, password):
"Non-string passwords will no longer be supported in Requests "
"3.0.0. Please convert the object you've passed in ({!r}) to "
"a string or bytes object in the near future to avoid "
"problems.".format(password),
"problems.".format(type(password)),
category=DeprecationWarning,
)
password = str(password)
@ -239,7 +239,7 @@ class HTTPDigestAuth(AuthBase):
"""
# If response is not 4xx, do not auth
# See https://github.com/requests/requests/issues/3772
# See https://github.com/psf/requests/issues/3772
if not 400 <= r.status_code < 500:
self._thread_local.num_401_calls = 1
return r

@ -43,6 +43,7 @@ if is_py2:
import cookielib
from Cookie import Morsel
from StringIO import StringIO
# Keep OrderedDict for backwards compatibility.
from collections import Callable, Mapping, MutableMapping, OrderedDict
@ -59,6 +60,7 @@ elif is_py3:
from http import cookiejar as cookielib
from http.cookies import Morsel
from io import StringIO
# Keep OrderedDict for backwards compatibility.
from collections import OrderedDict
from collections.abc import Callable, Mapping, MutableMapping

@ -94,11 +94,11 @@ class ChunkedEncodingError(RequestException):
class ContentDecodingError(RequestException, BaseHTTPError):
"""Failed to decode response content"""
"""Failed to decode response content."""
class StreamConsumedError(RequestException, TypeError):
"""The content for this response was already consumed"""
"""The content for this response was already consumed."""
class RetryError(RequestException):
@ -106,21 +106,18 @@ class RetryError(RequestException):
class UnrewindableBodyError(RequestException):
"""Requests encountered an error when trying to rewind a body"""
"""Requests encountered an error when trying to rewind a body."""
# Warnings
class RequestsWarning(Warning):
"""Base warning for Requests."""
pass
class FileModeWarning(RequestsWarning, DeprecationWarning):
"""A file was opened in text mode, but Requests determined its binary length."""
pass
class RequestsDependencyWarning(RequestsWarning):
"""An imported dependency doesn't match the expected version range."""
pass

@ -12,7 +12,7 @@ import sys
# Import encoding now, to avoid implicit import later.
# Implicit import within threads may cause LookupError when standard library is in a ZIP,
# such as in Embedded Python. See https://github.com/requests/requests/issues/3578.
# such as in Embedded Python. See https://github.com/psf/requests/issues/3578.
import encodings.idna
from urllib3.fields import RequestField
@ -273,13 +273,16 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
"""The fully mutable :class:`PreparedRequest <PreparedRequest>` object,
containing the exact bytes that will be sent to the server.
Generated from either a :class:`Request <Request>` object or manually.
Instances are generated from a :class:`Request <Request>` object, and
should not be instantiated manually; doing so may produce undesirable
effects.
Usage::
>>> import requests
>>> req = requests.Request('GET', 'https://httpbin.org/get')
>>> r = req.prepare()
>>> r
<PreparedRequest [GET]>
>>> s = requests.Session()
@ -358,7 +361,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
#: We're unable to blindly call unicode/str functions
#: as this will include the bytestring indicator (b'')
#: on python 3.x.
#: https://github.com/requests/requests/pull/2238
#: https://github.com/psf/requests/pull/2238
if isinstance(url, bytes):
url = url.decode('utf8')
else:
@ -472,12 +475,12 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
not isinstance(data, (basestring, list, tuple, Mapping))
])
try:
length = super_len(data)
except (TypeError, AttributeError, UnsupportedOperation):
length = None
if is_stream:
try:
length = super_len(data)
except (TypeError, AttributeError, UnsupportedOperation):
length = None
body = data
if getattr(body, 'tell', None) is not None:
@ -608,7 +611,7 @@ class Response(object):
#: File-like object representation of response (for advanced usage).
#: Use of ``raw`` requires that ``stream=True`` be set on the request.
# This requirement does not apply for use internally to Requests.
#: This requirement does not apply for use internally to Requests.
self.raw = None
#: Final URL location of Response.
@ -915,7 +918,7 @@ class Response(object):
return l
def raise_for_status(self):
"""Raises stored :class:`HTTPError`, if one occurred."""
"""Raises :class:`HTTPError`, if one occurred."""
http_error_msg = ''
if isinstance(self.reason, bytes):

@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
requests.session
~~~~~~~~~~~~~~~~
requests.sessions
~~~~~~~~~~~~~~~~~
This module provides a Session object to manage and persist settings across
requests (cookies, auth, proxies).
@ -11,9 +11,10 @@ import os
import sys
import time
from datetime import timedelta
from collections import OrderedDict
from .auth import _basic_auth_str
from .compat import cookielib, is_py3, OrderedDict, urljoin, urlparse, Mapping
from .compat import cookielib, is_py3, urljoin, urlparse, Mapping
from .cookies import (
cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar, merge_cookies)
from .models import Request, PreparedRequest, DEFAULT_REDIRECT_LIMIT
@ -162,7 +163,7 @@ class SessionRedirectMixin(object):
resp.raw.read(decode_content=False)
if len(resp.history) >= self.max_redirects:
raise TooManyRedirects('Exceeded %s redirects.' % self.max_redirects, response=resp)
raise TooManyRedirects('Exceeded {} redirects.'.format(self.max_redirects), response=resp)
# Release the connection back into the pool.
resp.close()
@ -170,7 +171,7 @@ class SessionRedirectMixin(object):
# Handle redirection without scheme (see: RFC 1808 Section 4)
if url.startswith('//'):
parsed_rurl = urlparse(resp.url)
url = '%s:%s' % (to_native_string(parsed_rurl.scheme), url)
url = ':'.join([to_native_string(parsed_rurl.scheme), url])
# Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
parsed = urlparse(url)
@ -192,19 +193,16 @@ class SessionRedirectMixin(object):
self.rebuild_method(prepared_request, resp)
# https://github.com/requests/requests/issues/1084
# https://github.com/psf/requests/issues/1084
if resp.status_code not in (codes.temporary_redirect, codes.permanent_redirect):
# https://github.com/requests/requests/issues/3490
# https://github.com/psf/requests/issues/3490
purged_headers = ('Content-Length', 'Content-Type', 'Transfer-Encoding')
for header in purged_headers:
prepared_request.headers.pop(header, None)
prepared_request.body = None
headers = prepared_request.headers
try:
del headers['Cookie']
except KeyError:
pass
headers.pop('Cookie', None)
# Extract any cookies sent on the response to the cookiejar
# in the new request. Because we've mutated our copied prepared
@ -271,7 +269,6 @@ class SessionRedirectMixin(object):
if new_auth is not None:
prepared_request.prepare_auth(new_auth)
return
def rebuild_proxies(self, prepared_request, proxies):
"""This method re-evaluates the proxy configuration by considering the
@ -352,13 +349,13 @@ class Session(SessionRedirectMixin):
Or as a context manager::
>>> with requests.Session() as s:
>>> s.get('https://httpbin.org/get')
... s.get('https://httpbin.org/get')
<Response [200]>
"""
__attrs__ = [
'headers', 'cookies', 'auth', 'proxies', 'hooks', 'params', 'verify',
'cert', 'prefetch', 'adapters', 'stream', 'trust_env',
'cert', 'adapters', 'stream', 'trust_env',
'max_redirects',
]
@ -390,6 +387,13 @@ class Session(SessionRedirectMixin):
self.stream = False
#: SSL Verification default.
#: Defaults to `True`, requiring requests to verify the TLS certificate at the
#: remote end.
#: If verify is set to `False`, requests will accept any TLS certificate
#: presented by the server, and will ignore hostname mismatches and/or
#: expired certificates, which will make your application vulnerable to
#: man-in-the-middle (MitM) attacks.
#: Only set this to `False` for testing.
self.verify = True
#: SSL client certificate default, if String, path to ssl client
@ -498,7 +502,12 @@ class Session(SessionRedirectMixin):
content. Defaults to ``False``.
:param verify: (optional) Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. Defaults to ``True``.
to a CA bundle to use. Defaults to ``True``. When set to
``False``, requests will accept any TLS certificate presented by
the server, and will ignore hostname mismatches and/or expired
certificates, which will make your application vulnerable to
man-in-the-middle (MitM) attacks. Setting verify to ``False``
may be useful during local development or testing.
:param cert: (optional) if String, path to ssl client cert file (.pem).
If Tuple, ('cert', 'key') pair.
:rtype: requests.Response
@ -661,11 +670,13 @@ class Session(SessionRedirectMixin):
extract_cookies_to_jar(self.cookies, request, r.raw)
# Redirect resolving generator.
gen = self.resolve_redirects(r, request, **kwargs)
# Resolve redirects if allowed.
history = [resp for resp in gen] if allow_redirects else []
if allow_redirects:
# Redirect resolving generator.
gen = self.resolve_redirects(r, request, **kwargs)
history = [resp for resp in gen]
else:
history = []
# Shuffle things around if there's history.
if history:
@ -728,7 +739,7 @@ class Session(SessionRedirectMixin):
return adapter
# Nothing matches :-/
raise InvalidSchema("No connection adapters were found for '%s'" % url)
raise InvalidSchema("No connection adapters were found for {!r}".format(url))
def close(self):
"""Closes all adapters and as such the session"""

@ -5,12 +5,15 @@ The ``codes`` object defines a mapping from common names for HTTP statuses
to their numerical codes, accessible either as attributes or as dictionary
items.
>>> requests.codes['temporary_redirect']
307
>>> requests.codes.teapot
418
>>> requests.codes['\o/']
200
Example::
>>> import requests
>>> requests.codes['temporary_redirect']
307
>>> requests.codes.teapot
418
>>> requests.codes['\o/']
200
Some codes have multiple names, and both upper- and lower-case versions of
the names are allowed. For example, ``codes.ok``, ``codes.OK``, and

@ -7,7 +7,9 @@ requests.structures
Data structures that power Requests.
"""
from .compat import OrderedDict, Mapping, MutableMapping
from collections import OrderedDict
from .compat import Mapping, MutableMapping
class CaseInsensitiveDict(MutableMapping):

@ -19,6 +19,7 @@ import sys
import tempfile
import warnings
import zipfile
from collections import OrderedDict
from .__version__ import __version__
from . import certs
@ -26,7 +27,7 @@ from . import certs
from ._internal_utils import to_native_string
from .compat import parse_http_list as _parse_list_header
from .compat import (
quote, urlparse, bytes, str, OrderedDict, unquote, getproxies,
quote, urlparse, bytes, str, unquote, getproxies,
proxy_bypass, urlunparse, basestring, integer_types, is_py3,
proxy_bypass_environment, getproxies_environment, Mapping)
from .cookies import cookiejar_from_dict
@ -168,18 +169,24 @@ def super_len(o):
def get_netrc_auth(url, raise_errors=False):
"""Returns the Requests tuple auth for a given url from netrc."""
netrc_file = os.environ.get('NETRC')
if netrc_file is not None:
netrc_locations = (netrc_file,)
else:
netrc_locations = ('~/{}'.format(f) for f in NETRC_FILES)
try:
from netrc import netrc, NetrcParseError
netrc_path = None
for f in NETRC_FILES:
for f in netrc_locations:
try:
loc = os.path.expanduser('~/{}'.format(f))
loc = os.path.expanduser(f)
except KeyError:
# os.path.expanduser can fail when $HOME is undefined and
# getpwuid fails. See https://bugs.python.org/issue20164 &
# https://github.com/requests/requests/issues/1846
# https://github.com/psf/requests/issues/1846
return
if os.path.exists(loc):
@ -211,7 +218,7 @@ def get_netrc_auth(url, raise_errors=False):
if raise_errors:
raise
# AppEngine hackiness.
# App Engine hackiness.
except (ImportError, AttributeError):
pass
@ -266,6 +273,8 @@ def from_key_val_list(value):
>>> from_key_val_list([('key', 'val')])
OrderedDict([('key', 'val')])
>>> from_key_val_list('string')
Traceback (most recent call last):
...
ValueError: cannot encode objects that are not 2-tuples
>>> from_key_val_list({'key': 'val'})
OrderedDict([('key', 'val')])
@ -292,7 +301,9 @@ def to_key_val_list(value):
>>> to_key_val_list({'key': 'val'})
[('key', 'val')]
>>> to_key_val_list('string')
ValueError: cannot encode objects that are not 2-tuples.
Traceback (most recent call last):
...
ValueError: cannot encode objects that are not 2-tuples
:rtype: list
"""
@ -492,6 +503,10 @@ def get_encoding_from_headers(headers):
if 'text' in content_type:
return 'ISO-8859-1'
if 'application/json' in content_type:
# Assume UTF-8 based on RFC 4627: https://www.ietf.org/rfc/rfc4627.txt since the charset was unset
return 'utf-8'
def stream_decode_response_unicode(iterator, r):
"""Stream decodes a iterator."""

@ -0,0 +1,8 @@
from gevent import monkey
monkey.patch_socket()
monkey.patch_ssl()
from ._connection import Connection
__version__ = '0.0.7'

@ -0,0 +1,91 @@
import json
import gevent
from signalr.events import EventHook
from signalr.hubs import Hub
from signalr.transports import AutoTransport
class Connection:
protocol_version = '1.5'
def __init__(self, url, session):
self.url = url
self.__hubs = {}
self.qs = {}
self.__send_counter = -1
self.token = None
self.data = None
self.received = EventHook()
self.error = EventHook()
self.starting = EventHook()
self.stopping = EventHook()
self.__transport = AutoTransport(session, self)
self.__greenlet = None
self.started = False
def handle_error(**kwargs):
error = kwargs["E"] if "E" in kwargs else None
if error is None:
return
self.error.fire(error)
self.received += handle_error
self.starting += self.__set_data
def __set_data(self):
self.data = json.dumps([{'name': hub_name} for hub_name in self.__hubs])
def increment_send_counter(self):
self.__send_counter += 1
return self.__send_counter
def start(self):
self.starting.fire()
negotiate_data = self.__transport.negotiate()
self.token = negotiate_data['ConnectionToken']
listener = self.__transport.start()
def wrapped_listener():
try:
listener()
gevent.sleep()
except Exception as e:
gevent.kill(self.__greenlet)
self.started = False
self.__greenlet = gevent.spawn(wrapped_listener)
self.started = True
def wait(self, timeout=30):
gevent.joinall([self.__greenlet], timeout)
def send(self, data):
self.__transport.send(data)
def close(self):
gevent.kill(self.__greenlet)
self.__transport.close()
def register_hub(self, name):
if name not in self.__hubs:
if self.started:
raise RuntimeError(
'Cannot create new hub because connection is already started.')
self.__hubs[name] = Hub(name, self)
return self.__hubs[name]
def hub(self, name):
return self.__hubs[name]
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

@ -0,0 +1 @@
from ._events import EventHook

@ -0,0 +1,15 @@
class EventHook(object):
def __init__(self):
self._handlers = []
def __iadd__(self, handler):
self._handlers.append(handler)
return self
def __isub__(self, handler):
self._handlers.remove(handler)
return self
def fire(self, *args, **kwargs):
for handler in self._handlers:
handler(*args, **kwargs)

@ -0,0 +1 @@
from ._hub import Hub

@ -0,0 +1,55 @@
from signalr.events import EventHook
class Hub:
def __init__(self, name, connection):
self.name = name
self.server = HubServer(name, connection, self)
self.client = HubClient(name, connection)
self.error = EventHook()
class HubServer:
def __init__(self, name, connection, hub):
self.name = name
self.__connection = connection
self.__hub = hub
def invoke(self, method, *data):
self.__connection.send({
'H': self.name,
'M': method,
'A': data,
'I': self.__connection.increment_send_counter()
})
class HubClient(object):
def __init__(self, name, connection):
self.name = name
self.__handlers = {}
def handle(**kwargs):
messages = kwargs['M'] if 'M' in kwargs and len(kwargs['M']) > 0 else {}
for inner_data in messages:
hub = inner_data['H'] if 'H' in inner_data else ''
if hub.lower() == self.name.lower():
method = inner_data['name']
if method in self.__handlers:
self.__handlers[method].fire(inner_data)
connection.received += handle
def on(self, method, handler):
if method not in self.__handlers:
self.__handlers[method] = EventHook()
self.__handlers[method] += handler
def off(self, method, handler):
if method in self.__handlers:
self.__handlers[method] -= handler
class DictToObj:
def __init__(self, d):
self.__dict__ = d

@ -0,0 +1 @@
from ._auto_transport import AutoTransport

@ -0,0 +1,37 @@
from ._transport import Transport
from ._sse_transport import ServerSentEventsTransport
from ._ws_transport import WebSocketsTransport
class AutoTransport(Transport):
def __init__(self, session, connection):
Transport.__init__(self, session, connection)
self.__available_transports = [
WebSocketsTransport(session, connection),
ServerSentEventsTransport(session, connection)
]
self.__transport = None
def negotiate(self):
negotiate_data = Transport.negotiate(self)
self.__transport = self.__get_transport(negotiate_data)
return negotiate_data
def __get_transport(self, negotiate_data):
for transport in self.__available_transports:
if transport.accept(negotiate_data):
return transport
raise Exception('Cannot find suitable transport')
def start(self):
return self.__transport.start()
def send(self, data):
self.__transport.send(data)
def close(self):
self.__transport.close()
def _get_name(self):
return 'auto'

@ -0,0 +1,35 @@
import json
import sseclient
from ._transport import Transport
from requests.exceptions import ConnectionError
class ServerSentEventsTransport(Transport):
def __init__(self, session, connection):
Transport.__init__(self, session, connection)
self.__response = None
def _get_name(self):
return 'serverSentEvents'
def start(self):
self.__response = sseclient.SSEClient(self._get_url('connect'), session=self._session)
self._session.get(self._get_url('start'))
def _receive():
try:
for notification in self.__response:
if notification.data != 'initialized':
self._handle_notification(notification.data)
except ConnectionError:
raise ConnectionError
return _receive
def send(self, data):
response = self._session.post(self._get_url('send'), data={'data': json.dumps(data)})
parsed = json.loads(response.content)
self._connection.received.fire(**parsed)
def close(self):
self._session.get(self._get_url('abort'))

@ -0,0 +1,70 @@
from abc import abstractmethod
import json
import sys
if sys.version_info[0] < 3:
from urllib import quote_plus
else:
from urllib.parse import quote_plus
import gevent
class Transport:
def __init__(self, session, connection):
self._session = session
self._connection = connection
@abstractmethod
def _get_name(self):
pass
def negotiate(self):
url = self.__get_base_url(self._connection,
'negotiate',
connectionData=self._connection.data)
negotiate = self._session.get(url)
negotiate.raise_for_status()
return negotiate.json()
@abstractmethod
def start(self):
pass
@abstractmethod
def send(self, data):
pass
@abstractmethod
def close(self):
pass
def accept(self, negotiate_data):
return True
def _handle_notification(self, message):
if len(message) > 0:
data = json.loads(message)
self._connection.received.fire(**data)
gevent.sleep()
def _get_url(self, action, **kwargs):
args = kwargs.copy()
args['transport'] = self._get_name()
args['connectionToken'] = self._connection.token
args['connectionData'] = self._connection.data
return self.__get_base_url(self._connection, action, **args)
@staticmethod
def __get_base_url(connection, action, **kwargs):
args = kwargs.copy()
args.update(connection.qs)
args['clientProtocol'] = connection.protocol_version
query = '&'.join(['{key}={value}'.format(key=key, value=quote_plus(args[key])) for key in args])
return '{url}/{action}?{query}'.format(url=connection.url,
action=action,
query=query)

@ -0,0 +1,77 @@
import json
import sys
import gevent
if sys.version_info[0] < 3:
from urlparse import urlparse, urlunparse
else:
from urllib.parse import urlparse, urlunparse
from websocket import create_connection
from ._transport import Transport
class WebSocketsTransport(Transport):
def __init__(self, session, connection):
Transport.__init__(self, session, connection)
self.ws = None
self.__requests = {}
def _get_name(self):
return 'webSockets'
@staticmethod
def __get_ws_url_from(url):
parsed = urlparse(url)
scheme = 'wss' if parsed.scheme == 'https' else 'ws'
url_data = (scheme, parsed.netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
return urlunparse(url_data)
def start(self):
ws_url = self.__get_ws_url_from(self._get_url('connect'))
self.ws = create_connection(ws_url,
header=self.__get_headers(),
cookie=self.__get_cookie_str(),
enable_multithread=True)
self._session.get(self._get_url('start'))
def _receive():
try:
for notification in self.ws:
self._handle_notification(notification)
except ConnectionError:
raise ConnectionError
return _receive
def send(self, data):
self.ws.send(json.dumps(data))
gevent.sleep()
def close(self):
self.ws.close()
def accept(self, negotiate_data):
return bool(negotiate_data['TryWebSockets'])
class HeadersLoader(object):
def __init__(self, headers):
self.headers = headers
def __get_headers(self):
headers = self._session.headers
loader = WebSocketsTransport.HeadersLoader(headers)
if self._session.auth:
self._session.auth(loader)
return ['%s: %s' % (name, headers[name]) for name in headers]
def __get_cookie_str(self):
return '; '.join([
'%s=%s' % (name, value)
for name, value in self._session.cookies.items()
])

@ -0,0 +1,99 @@
import logging
import urllib.parse as parse
class Helpers:
@staticmethod
def configure_logger(level=logging.INFO, handler=None):
logger = Helpers.get_logger()
if handler is None:
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
handler.setLevel(level)
logger.addHandler(handler)
logger.setLevel(level)
@staticmethod
def get_logger():
return logging.getLogger("SignalRCoreClient")
@staticmethod
def has_querystring(url):
return "?" in url
@staticmethod
def split_querystring(url):
parts = url.split("?")
return parts[0], parts[1]
@staticmethod
def replace_scheme(
url,
root_scheme,
source,
secure_source,
destination,
secure_destination):
url_parts = parse.urlsplit(url)
if root_scheme not in url_parts.scheme:
if url_parts.scheme == secure_source:
url_parts = url_parts._replace(scheme=secure_destination)
if url_parts.scheme == source:
url_parts = url_parts._replace(scheme=destination)
return parse.urlunsplit(url_parts)
@staticmethod
def websocket_to_http(url):
return Helpers.replace_scheme(
url,
"http",
"ws",
"wss",
"http",
"https")
@staticmethod
def http_to_websocket(url):
return Helpers.replace_scheme(
url,
"ws",
"http",
"https",
"ws",
"wss"
)
@staticmethod
def get_negotiate_url(url):
querystring = ""
if Helpers.has_querystring(url):
url, querystring = Helpers.split_querystring(url)
url_parts = parse.urlsplit(Helpers.websocket_to_http(url))
negotiate_suffix = "negotiate"\
if url_parts.path.endswith('/')\
else "/negotiate"
url_parts = url_parts._replace(path=url_parts.path + negotiate_suffix)
return parse.urlunsplit(url_parts) \
if querystring == "" else\
parse.urlunsplit(url_parts) + "?" + querystring
@staticmethod
def encode_connection_id(url, id):
url_parts = parse.urlsplit(url)
query_string_parts = parse.parse_qs(url_parts.query)
query_string_parts["id"] = id
url_parts = url_parts._replace(
query=parse.urlencode(
query_string_parts,
doseq=True))
return Helpers.http_to_websocket(parse.urlunsplit(url_parts))

@ -0,0 +1,22 @@
from .base_hub_connection import BaseHubConnection
from ..helpers import Helpers
class AuthHubConnection(BaseHubConnection):
def __init__(self, auth_function, headers={}, **kwargs):
self.headers = headers
self.auth_function = auth_function
super(AuthHubConnection, self).__init__(**kwargs)
def start(self):
try:
Helpers.get_logger().debug("Starting connection ...")
self.token = self.auth_function()
Helpers.get_logger()\
.debug("auth function result {0}".format(self.token))
self.headers["Authorization"] = "Bearer " + self.token
return super(AuthHubConnection, self).start()
except Exception as ex:
Helpers.get_logger().warning(self.__class__.__name__)
Helpers.get_logger().warning(str(ex))
raise ex

@ -0,0 +1,238 @@
import websocket
import threading
import requests
import traceback
import uuid
import time
import ssl
from typing import Callable
from signalrcore.messages.message_type import MessageType
from signalrcore.messages.stream_invocation_message\
import StreamInvocationMessage
from signalrcore.messages.ping_message import PingMessage
from .errors import UnAuthorizedHubError, HubError, HubConnectionError
from signalrcore.helpers import Helpers
from .handlers import StreamHandler, InvocationHandler
from ..protocol.messagepack_protocol import MessagePackHubProtocol
from ..transport.websockets.websocket_transport import WebsocketTransport
from ..helpers import Helpers
from ..subject import Subject
from ..messages.invocation_message import InvocationMessage
class BaseHubConnection(object):
def __init__(
self,
url,
protocol,
headers={},
**kwargs):
self.headers = headers
self.logger = Helpers.get_logger()
self.handlers = []
self.stream_handlers = []
self._on_error = lambda error: self.logger.info(
"on_error not defined {0}".format(error))
self.transport = WebsocketTransport(
url=url,
protocol=protocol,
headers=headers,
on_message=self.on_message,
**kwargs)
def start(self):
self.logger.debug("Connection started")
return self.transport.start()
def stop(self):
self.logger.debug("Connection stop")
return self.transport.stop()
def on_close(self, callback):
"""Configures on_close connection callback.
It will be raised on connection closed event
connection.on_close(lambda: print("connection closed"))
Args:
callback (function): function without params
"""
self.transport.on_close_callback(callback)
def on_open(self, callback):
"""Configures on_open connection callback.
It will be raised on connection open event
connection.on_open(lambda: print(
"connection opened "))
Args:
callback (function): funciton without params
"""
self.transport.on_open_callback(callback)
def on_error(self, callback):
"""Configures on_error connection callback. It will be raised
if any hub method throws an exception.
connection.on_error(lambda data:
print(f"An exception was thrown closed{data.error}"))
Args:
callback (function): function with one parameter.
A CompletionMessage object.
"""
self._on_error = callback
def on(self, event, callback_function: Callable):
"""Register a callback on the specified event
Args:
event (string): Event name
callback_function (Function): callback function,
arguments will be binded
"""
self.logger.debug("Handler registered started {0}".format(event))
self.handlers.append((event, callback_function))
def send(self, method, arguments, on_invocation=None):
"""Sends a message
Args:
method (string): Method name
arguments (list|Subject): Method parameters
on_invocation (function, optional): On invocation send callback
will be raised on send server function ends. Defaults to None.
Raises:
HubConnectionError: If hub is not ready to send
TypeError: If arguments are invalid list or Subject
"""
if not self.transport.is_running():
raise HubConnectionError(
"Hub is not running you cand send messages")
if type(arguments) is not list and type(arguments) is not Subject:
raise TypeError("Arguments of a message must be a list or subject")
if type(arguments) is list:
message = InvocationMessage(
str(uuid.uuid4()),
method,
arguments,
headers=self.headers)
if on_invocation:
self.stream_handlers.append(
InvocationHandler(
message.invocation_id,
on_invocation))
self.transport.send(message)
if type(arguments) is Subject:
arguments.connection = self
arguments.target = method
arguments.start()
def on_message(self, messages):
for message in messages:
if message.type == MessageType.invocation_binding_failure:
self.logger.error(message)
self._on_error(message)
continue
if message.type == MessageType.ping:
continue
if message.type == MessageType.invocation:
fired_handlers = list(
filter(
lambda h: h[0] == message.target,
self.handlers))
if len(fired_handlers) == 0:
self.logger.warning(
"event '{0}' hasn't fire any handler".format(
message.target))
for _, handler in fired_handlers:
handler(message.arguments)
if message.type == MessageType.close:
self.logger.info("Close message received from server")
self.stop()
return
if message.type == MessageType.completion:
if message.error is not None and len(message.error) > 0:
self._on_error(message)
# Send callbacks
fired_handlers = list(
filter(
lambda h: h.invocation_id == message.invocation_id,
self.stream_handlers))
# Stream callbacks
for handler in fired_handlers:
handler.complete_callback(message)
# unregister handler
self.stream_handlers = list(
filter(
lambda h: h.invocation_id != message.invocation_id,
self.stream_handlers))
if message.type == MessageType.stream_item:
fired_handlers = list(
filter(
lambda h: h.invocation_id == message.invocation_id,
self.stream_handlers))
if len(fired_handlers) == 0:
self.logger.warning(
"id '{0}' hasn't fire any stream handler".format(
message.invocation_id))
for handler in fired_handlers:
handler.next_callback(message.item)
if message.type == MessageType.stream_invocation:
pass
if message.type == MessageType.cancel_invocation:
fired_handlers = list(
filter(
lambda h: h.invocation_id == message.invocation_id,
self.stream_handlers))
if len(fired_handlers) == 0:
self.logger.warning(
"id '{0}' hasn't fire any stream handler".format(
message.invocation_id))
for handler in fired_handlers:
handler.error_callback(message)
# unregister handler
self.stream_handlers = list(
filter(
lambda h: h.invocation_id != message.invocation_id,
self.stream_handlers))
def stream(self, event, event_params):
"""Starts server streaming
connection.stream(
"Counter",
[len(self.items), 500])\
.subscribe({
"next": self.on_next,
"complete": self.on_complete,
"error": self.on_error
})
Args:
event (string): Method Name
event_params (list): Method parameters
Returns:
[StreamHandler]: stream handler
"""
invocation_id = str(uuid.uuid4())
stream_obj = StreamHandler(event, invocation_id)
self.stream_handlers.append(stream_obj)
self.transport.send(
StreamInvocationMessage(
invocation_id,
event,
event_params,
headers=self.headers))
return stream_obj

@ -0,0 +1,10 @@
class HubError(OSError):
pass
class UnAuthorizedHubError(HubError):
pass
class HubConnectionError(ValueError):
"""Hub connection error
"""
pass

@ -0,0 +1,51 @@
import logging
from typing import Callable
from ..helpers import Helpers
class StreamHandler(object):
def __init__(self, event: str, invocation_id: str):
self.event = event
self.invocation_id = invocation_id
self.logger = Helpers.get_logger()
self.next_callback =\
lambda _: self.logger.warning(
"next stream handler fired, no callback configured")
self.complete_callback =\
lambda _: self.logger.warning(
"next complete handler fired, no callback configured")
self.error_callback =\
lambda _: self.logger.warning(
"next error handler fired, no callback configured")
def subscribe(self, subscribe_callbacks: dict):
error =\
" subscribe object must be a dict like {0}"\
.format({
"next": None,
"complete": None,
"error": None
})
if subscribe_callbacks is None or\
type(subscribe_callbacks) is not dict:
raise TypeError(error)
if "next" not in subscribe_callbacks or\
"complete" not in subscribe_callbacks \
or "error" not in subscribe_callbacks:
raise KeyError(error)
if not callable(subscribe_callbacks["next"])\
or not callable(subscribe_callbacks["next"]) \
or not callable(subscribe_callbacks["next"]):
raise ValueError("Suscribe callbacks must be functions")
self.next_callback = subscribe_callbacks["next"]
self.complete_callback = subscribe_callbacks["complete"]
self.error_callback = subscribe_callbacks["error"]
class InvocationHandler(object):
def __init__(self, invocation_id: str, complete_callback: Callable):
self.invocation_id = invocation_id
self.complete_callback = complete_callback

@ -0,0 +1,245 @@
import uuid
from .hub.base_hub_connection import BaseHubConnection
from .hub.auth_hub_connection import AuthHubConnection
from .transport.websockets.reconnection import \
IntervalReconnectionHandler, RawReconnectionHandler, ReconnectionType
from .helpers import Helpers
from .messages.invocation_message import InvocationMessage
from .protocol.json_hub_protocol import JsonHubProtocol
from .subject import Subject
class HubConnectionBuilder(object):
"""
Hub connection class, manages handshake and messaging
Args:
hub_url: SignalR core url
Raises:
HubConnectionError: Raises an Exception if url is empty or None
"""
def __init__(self):
self.hub_url = None
self.hub = None
self.options = {
"access_token_factory": None
}
self.token = None
self.headers = None
self.negotiate_headers = None
self.has_auth_configured = None
self.protocol = None
self.reconnection_handler = None
self.keep_alive_interval = None
self.verify_ssl = True
self.enable_trace = False # socket trace
self.skip_negotiation = False # By default do not skip negotiation
self.running = False
def with_url(
self,
hub_url: str,
options: dict = None):
"""Configure the hub url and options like negotiation and auth function
def login(self):
response = requests.post(
self.login_url,
json={
"username": self.email,
"password": self.password
},verify=False)
return response.json()["token"]
self.connection = HubConnectionBuilder()\
.with_url(self.server_url,
options={
"verify_ssl": False,
"access_token_factory": self.login,
"headers": {
"mycustomheader": "mycustomheadervalue"
}
})\
.configure_logging(logging.ERROR)\
.with_automatic_reconnect({
"type": "raw",
"keep_alive_interval": 10,
"reconnect_interval": 5,
"max_attempts": 5
}).build()
Args:
hub_url (string): Hub URL
options ([dict], optional): [description]. Defaults to None.
Raises:
ValueError: If url is invalid
TypeError: If options are not a dict or auth function
is not callable
Returns:
[HubConnectionBuilder]: configured connection
"""
if hub_url is None or hub_url.strip() == "":
raise ValueError("hub_url must be a valid url.")
if options is not None and type(options) != dict:
raise TypeError(
"options must be a dict {0}.".format(self.options))
if options is not None \
and "access_token_factory" in options.keys()\
and not callable(options["access_token_factory"]):
raise TypeError(
"access_token_factory must be a function without params")
if options is not None:
self.has_auth_configured = \
"access_token_factory" in options.keys()\
and callable(options["access_token_factory"])
self.skip_negotiation = "skip_negotiation" in options.keys()\
and options["skip_negotiation"]
self.hub_url = hub_url
self.hub = None
self.options = self.options if options is None else options
return self
def configure_logging(
self, logging_level, socket_trace=False, handler=None):
"""Configures signalr logging
Args:
logging_level ([type]): logging.INFO | logging.DEBUG ...
from python logging class
socket_trace (bool, optional): Enables socket package trace.
Defaults to False.
handler ([type], optional): Custom logging handler.
Defaults to None.
Returns:
[HubConnectionBuilder]: Instance hub with logging configured
"""
Helpers.configure_logger(logging_level, handler)
self.enable_trace = socket_trace
return self
def with_hub_protocol(self, protocol):
"""Changes transport protocol
from signalrcore.protocol.messagepack_protocol\
import MessagePackHubProtocol
HubConnectionBuilder()\
.with_url(self.server_url, options={"verify_ssl":False})\
...
.with_hub_protocol(MessagePackHubProtocol())\
...
.build()
Args:
protocol (JsonHubProtocol|MessagePackHubProtocol):
protocol instance
Returns:
HubConnectionBuilder: instance configured
"""
self.protocol = protocol
return self
def build(self):
"""Configures the connection hub
Raises:
TypeError: Checks parameters an raises TypeError
if one of them is wrong
Returns:
[HubConnectionBuilder]: [self object for fluent interface purposes]
"""
if self.protocol is None:
self.protocol = JsonHubProtocol()
self.headers = {}
if "headers" in self.options.keys()\
and type(self.options["headers"]) is dict:
self.headers = self.options["headers"]
if self.has_auth_configured:
auth_function = self.options["access_token_factory"]
if auth_function is None or not callable(auth_function):
raise TypeError(
"access_token_factory is not function")
if "verify_ssl" in self.options.keys()\
and type(self.options["verify_ssl"]) is bool:
self.verify_ssl = self.options["verify_ssl"]
return AuthHubConnection(
headers=self.headers,
auth_function=auth_function,
url=self.hub_url,
protocol=self.protocol,
keep_alive_interval=self.keep_alive_interval,
reconnection_handler=self.reconnection_handler,
verify_ssl=self.verify_ssl,
skip_negotiation=self.skip_negotiation,
enable_trace=self.enable_trace)\
if self.has_auth_configured else\
BaseHubConnection(
url=self.hub_url,
protocol=self.protocol,
keep_alive_interval=self.keep_alive_interval,
reconnection_handler=self.reconnection_handler,
headers=self.headers,
verify_ssl=self.verify_ssl,
skip_negotiation=self.skip_negotiation,
enable_trace=self.enable_trace)
def with_automatic_reconnect(self, data: dict):
"""Configures automatic reconnection
https://devblogs.microsoft.com/aspnet/asp-net-core-updates-in-net-core-3-0-preview-4/
hub = HubConnectionBuilder()\
.with_url(self.server_url, options={"verify_ssl":False})\
.configure_logging(logging.ERROR)\
.with_automatic_reconnect({
"type": "raw",
"keep_alive_interval": 10,
"reconnect_interval": 5,
"max_attempts": 5
})\
.build()
Args:
data (dict): [dict with autmatic reconnection parameters]
Returns:
[HubConnectionBuilder]: [self object for fluent interface purposes]
"""
reconnect_type = data.get("type", "raw")
# Infinite reconnect attempts
max_attempts = data.get("max_attempts", None)
# 5 sec interval
reconnect_interval = data.get("reconnect_interval", 5)
keep_alive_interval = data.get("keep_alive_interval", 15)
intervals = data.get("intervals", []) # Reconnection intervals
self.keep_alive_interval = keep_alive_interval
reconnection_type = ReconnectionType[reconnect_type]
if reconnection_type == ReconnectionType.raw:
self.reconnection_handler = RawReconnectionHandler(
reconnect_interval,
max_attempts
)
if reconnection_type == ReconnectionType.interval:
self.reconnection_handler = IntervalReconnectionHandler(
intervals
)
return self

@ -0,0 +1,15 @@
from .message_type import MessageType
class BaseMessage(object):
def __init__(self, message_type, **kwargs):
self.type = MessageType(message_type)
class BaseHeadersMessage(BaseMessage):
"""
All messages expct ping can carry aditional headers
"""
def __init__(self, message_type, headers={}, **kwargs):
super(BaseHeadersMessage, self).__init__(message_type)
self.headers = headers

@ -0,0 +1,24 @@
from .base_message import BaseHeadersMessage
"""
A `CancelInvocation` message is a JSON object with the following properties
* `type` - A `Number` with the literal value `5`,
indicating that this message is a `CancelInvocation`.
* `invocationId` - A `String` encoding the `Invocation ID` for a message.
Example
```json
{
"type": 5,
"invocationId": "123"
}
"""
class CancelInvocationMessage(BaseHeadersMessage):
def __init__(
self,
invocation_id,
**kwargs):
super(CancelInvocationMessage, self).__init__(5, **kwargs)
self.invocation_id = invocation_id

@ -0,0 +1,32 @@
from .base_message import BaseHeadersMessage
"""
A `Close` message is a JSON object with the following properties
* `type` - A `Number` with the literal value `7`,
indicating that this message is a `Close`.
* `error` - An optional `String` encoding the error message.
Example - A `Close` message without an error
```json
{
"type": 7
}
```
Example - A `Close` message with an error
```json
{
"type": 7,
"error": "Connection closed because of an error!"
}
```
"""
class CloseMessage(BaseHeadersMessage):
def __init__(
self,
error,
**kwargs):
super(CloseMessage, self).__init__(7, **kwargs)
self.error = error

@ -0,0 +1,77 @@
from .base_message import BaseHeadersMessage
"""
A `Completion` message is a JSON object with the following properties
* `type` - A `Number` with the literal value `3`,
indicating that this message is a `Completion`.
* `invocationId` - A `String` encoding the `Invocation ID` for a message.
* `result` - A `Token` encoding the result value
(see "JSON Payload Encoding" for details).
This field is **ignored** if `error` is present.
* `error` - A `String` encoding the error message.
It is a protocol error to include both a `result` and an `error` property
in the `Completion` message. A conforming endpoint may immediately
terminate the connection upon receiving such a message.
Example - A `Completion` message with no result or error
```json
{
"type": 3,
"invocationId": "123"
}
```
Example - A `Completion` message with a result
```json
{
"type": 3,
"invocationId": "123",
"result": 42
}
```
Example - A `Completion` message with an error
```json
{
"type": 3,
"invocationId": "123",
"error": "It didn't work!"
}
```
Example - The following `Completion` message is a protocol error
because it has both of `result` and `error`
```json
{
"type": 3,
"invocationId": "123",
"result": 42,
"error": "It didn't work!"
}
```
"""
class CompletionClientStreamMessage(BaseHeadersMessage):
def __init__(
self, invocation_id, **kwargs):
super(CompletionClientStreamMessage, self).__init__(3, **kwargs)
self.invocation_id = invocation_id
class CompletionMessage(BaseHeadersMessage):
def __init__(
self,
invocation_id,
result,
error,
**kwargs):
super(CompletionMessage, self).__init__(3, **kwargs)
self.invocation_id = invocation_id
self.result = result
self.error = error

@ -0,0 +1,5 @@
class HandshakeRequestMessage(object):
def __init__(self, protocol, version):
self.protocol = protocol
self.version = version

@ -0,0 +1,4 @@
class HandshakeResponseMessage(object):
def __init__(self, error):
self.error = error

@ -0,0 +1,78 @@
from .base_message import BaseHeadersMessage
"""
An `Invocation` message is a JSON object with the following properties:
* `type` - A `Number` with the literal value 1, indicating that this message
is an Invocation.
* `invocationId` - An optional `String` encoding the `Invocation ID`
for a message.
* `target` - A `String` encoding the `Target` name, as expected by the Callee's
Binder
* `arguments` - An `Array` containing arguments to apply to the method
referred to in Target. This is a sequence of JSON `Token`s,
encoded as indicated below in the "JSON Payload Encoding" section
Example:
```json
{
"type": 1,
"invocationId": "123",
"target": "Send",
"arguments": [
42,
"Test Message"
]
}
```
Example (Non-Blocking):
```json
{
"type": 1,
"target": "Send",
"arguments": [
42,
"Test Message"
]
}
```
"""
class InvocationMessage(BaseHeadersMessage):
def __init__(
self,
invocation_id,
target,
arguments, **kwargs):
super(InvocationMessage, self).__init__(1, **kwargs)
self.invocation_id = invocation_id
self.target = target
self.arguments = arguments
def __repr__(self):
repr_str =\
"InvocationMessage: invocation_id {0}, target {1}, arguments {2}"
return repr_str.format(self.invocation_id, self.target, self.arguments)
class InvocationClientStreamMessage(BaseHeadersMessage):
def __init__(
self,
stream_ids,
target,
arguments,
**kwargs):
super(InvocationClientStreamMessage, self).__init__(1, **kwargs)
self.target = target
self.arguments = arguments
self.stream_ids = stream_ids
def __repr__(self):
repr_str =\
"InvocationMessage: stream_ids {0}, target {1}, arguments {2}"
return repr_str.format(
self.stream_ids, self.target, self.arguments)

@ -0,0 +1,12 @@
from enum import Enum
class MessageType(Enum):
invocation = 1
stream_item = 2
completion = 3
stream_invocation = 4
cancel_invocation = 5
ping = 6
close = 7
invocation_binding_failure = -1

@ -0,0 +1,20 @@
from .base_message import BaseMessage
"""
A `Ping` message is a JSON object with the following properties:
* `type` - A `Number` with the literal value `6`,
indicating that this message is a `Ping`.
Example
```json
{
"type": 6
}
```
"""
class PingMessage(BaseMessage):
def __init__(
self, **kwargs):
super(PingMessage, self).__init__(6, **kwargs)

@ -0,0 +1,42 @@
from .base_message import BaseHeadersMessage
"""
A `StreamInvocation` message is a JSON object with the following properties:
* `type` - A `Number` with the literal value 4, indicating that
this message is a StreamInvocation.
* `invocationId` - A `String` encoding the `Invocation ID` for a message.
* `target` - A `String` encoding the `Target` name, as expected
by the Callee's Binder.
* `arguments` - An `Array` containing arguments to apply to
the method referred to in Target. This is a sequence of JSON
`Token`s, encoded as indicated below in the
"JSON Payload Encoding" section.
Example:
```json
{
"type": 4,
"invocationId": "123",
"target": "Send",
"arguments": [
42,
"Test Message"
]
}
```
"""
class StreamInvocationMessage(BaseHeadersMessage):
def __init__(
self,
invocation_id,
target,
arguments,
**kwargs):
super(StreamInvocationMessage, self).__init__(4, **kwargs)
self.invocation_id = invocation_id
self.target = target
self.arguments = arguments
self.stream_ids = []

@ -0,0 +1,31 @@
from .base_message import BaseHeadersMessage
"""
A `StreamItem` message is a JSON object with the following properties:
* `type` - A `Number` with the literal value 2, indicating
that this message is a `StreamItem`.
* `invocationId` - A `String` encoding the `Invocation ID` for a message.
* `item` - A `Token` encoding the stream item
(see "JSON Payload Encoding" for details).
Example
```json
{
"type": 2,
"invocationId": "123",
"item": 42
}
```
"""
class StreamItemMessage(BaseHeadersMessage):
def __init__(
self,
invocation_id,
item,
**kwargs):
super(StreamItemMessage, self).__init__(2, **kwargs)
self.invocation_id = invocation_id
self.item = item

@ -0,0 +1,60 @@
import json
from ..messages.handshake.request import HandshakeRequestMessage
from ..messages.handshake.response import HandshakeResponseMessage
from ..messages.invocation_message import InvocationMessage # 1
from ..messages.stream_item_message import StreamItemMessage # 2
from ..messages.completion_message import CompletionMessage # 3
from ..messages.stream_invocation_message import StreamInvocationMessage # 4
from ..messages.cancel_invocation_message import CancelInvocationMessage # 5
from ..messages.ping_message import PingMessage # 6
from ..messages.close_message import CloseMessage # 7
from ..messages.message_type import MessageType
from ..helpers import Helpers
class BaseHubProtocol(object):
def __init__(self, protocol, version, transfer_format, record_separator):
self.protocol = protocol
self.version = version
self.transfer_format = transfer_format
self.record_separator = record_separator
@staticmethod
def get_message(dict_message):
message_type = MessageType.close\
if not "type" in dict_message.keys() else MessageType(dict_message["type"])
dict_message["invocation_id"] = dict_message.get("invocationId", None)
dict_message["headers"] = dict_message.get("headers", {})
dict_message["error"] = dict_message.get("error", None)
dict_message["result"] = dict_message.get("result", None)
if message_type is MessageType.invocation:
return InvocationMessage(**dict_message)
if message_type is MessageType.stream_item:
return StreamItemMessage(**dict_message)
if message_type is MessageType.completion:
return CompletionMessage(**dict_message)
if message_type is MessageType.stream_invocation:
return StreamInvocationMessage(**dict_message)
if message_type is MessageType.cancel_invocation:
return CancelInvocationMessage(**dict_message)
if message_type is MessageType.ping:
return PingMessage()
if message_type is MessageType.close:
return CloseMessage(**dict_message)
def decode_handshake(self, raw_message: str) -> HandshakeResponseMessage:
messages = raw_message.split(self.record_separator)
messages = list(filter(lambda x: x != "", messages))
data = json.loads(messages[0])
idx = raw_message.index(self.record_separator)
return HandshakeResponseMessage(data.get("error", None)), self.parse_messages(raw_message[idx + 1 :]) if len(messages) > 1 else []
def handshake_message(self) -> HandshakeRequestMessage:
return HandshakeRequestMessage(self.protocol, self.version)
def parse_messages(self, raw_message: str):
raise ValueError("Protocol must implement this method")
def write_message(self, hub_message):
raise ValueError("Protocol must implement this method")

@ -0,0 +1,50 @@
import json
from .base_hub_protocol import BaseHubProtocol
from ..messages.message_type import MessageType
from json import JSONEncoder
from signalrcore.helpers import Helpers
class MyEncoder(JSONEncoder):
# https://github.com/PyCQA/pylint/issues/414
def default(self, o):
if type(o) is MessageType:
return o.value
data = o.__dict__
if "invocation_id" in data:
data["invocationId"] = data["invocation_id"]
del data["invocation_id"]
if "stream_ids" in data:
data["streamIds"] = data["stream_ids"]
del data["stream_ids"]
return data
class JsonHubProtocol(BaseHubProtocol):
def __init__(self):
super(JsonHubProtocol, self).__init__("json", 1, "Text", chr(0x1E))
self.encoder = MyEncoder()
def parse_messages(self, raw):
Helpers.get_logger().debug("Raw message incomming: ")
Helpers.get_logger().debug(raw)
raw_messages = [
record.replace(self.record_separator, "")
for record in raw.split(self.record_separator)
if record is not None and record != ""
and record != self.record_separator
]
result = []
for raw_message in raw_messages:
dict_message = json.loads(raw_message)
if len(dict_message.keys()) > 0:
result.append(self.get_message(dict_message))
return result
def encode(self, message):
Helpers.get_logger()\
.debug(self.encoder.encode(message) + self.record_separator)
return self.encoder.encode(message) + self.record_separator

@ -0,0 +1,169 @@
import json
import msgpack
from .base_hub_protocol import BaseHubProtocol
from ..messages.handshake.request import HandshakeRequestMessage
from ..messages.handshake.response import HandshakeResponseMessage
from ..messages.invocation_message\
import InvocationMessage, InvocationClientStreamMessage # 1
from ..messages.stream_item_message import StreamItemMessage # 2
from ..messages.completion_message import CompletionMessage # 3
from ..messages.stream_invocation_message import StreamInvocationMessage # 4
from ..messages.cancel_invocation_message import CancelInvocationMessage # 5
from ..messages.ping_message import PingMessage # 6
from ..messages.close_message import CloseMessage # 7
from ..helpers import Helpers
class MessagePackHubProtocol(BaseHubProtocol):
_priority = [
"type",
"headers",
"invocation_id",
"target",
"arguments",
"item",
"result_kind",
"result",
"stream_ids"
]
def __init__(self):
super(MessagePackHubProtocol, self).__init__(
"messagepack", 1, "Text", chr(0x1E))
self.logger = Helpers.get_logger()
def parse_messages(self, raw):
try:
messages = []
offset = 0
while offset < len(raw):
length = msgpack.unpackb(raw[offset: offset + 1])
values = msgpack.unpackb(raw[offset + 1: offset + length + 1])
offset = offset + length + 1
message = self._decode_message(values)
messages.append(message)
except Exception as ex:
Helpers.get_logger().error("Parse messages Error {0}".format(ex))
Helpers.get_logger().error("raw msg '{0}'".format(raw))
return messages
def decode_handshake(self, raw_message):
try:
has_various_messages = 0x1E in raw_message
handshake_data = raw_message[0: raw_message.index(0x1E)] if has_various_messages else raw_message
messages = self.parse_messages(raw_message[raw_message.index(0x1E) + 1:]) if has_various_messages else []
data = json.loads(handshake_data)
return HandshakeResponseMessage(data.get("error", None)), messages
except Exception as ex:
Helpers.get_logger().error(raw_message)
Helpers.get_logger().error(ex)
raise ex
def encode(self, message):
if type(message) is HandshakeRequestMessage:
content = json.dumps(message.__dict__)
return content + self.record_separator
msg = self._encode_message(message)
encoded_message = msgpack.packb(msg)
varint_length = self._to_varint(len(encoded_message))
return varint_length + encoded_message
def _encode_message(self, message):
result = []
# sort attributes
for attribute in self._priority:
if hasattr(message, attribute):
if (attribute == "type"):
result.append(getattr(message, attribute).value)
else:
result.append(getattr(message, attribute))
return result
def _decode_message(self, raw):
# {} {"error"}
# [1, Headers, InvocationId, Target, [Arguments], [StreamIds]]
# [2, Headers, InvocationId, Item]
# [3, Headers, InvocationId, ResultKind, Result]
# [4, Headers, InvocationId, Target, [Arguments], [StreamIds]]
# [5, Headers, InvocationId]
# [6]
# [7, Error, AllowReconnect?]
if raw[0] == 1: # InvocationMessage
if len(raw[5]) > 0:
return InvocationClientStreamMessage(
headers=raw[1],
stream_ids=raw[5],
target=raw[3],
arguments=raw[4])
else:
return InvocationMessage(
headers=raw[1],
invocation_id=raw[2],
target=raw[3],
arguments=raw[4])
elif raw[0] == 2: # StreamItemMessage
return StreamItemMessage(
headers=raw[1],
invocation_id=raw[2],
item=raw[3])
elif raw[0] == 3: # CompletionMessage
result_kind = raw[3]
if result_kind == 1:
return CompletionMessage(
headers=raw[1],
invocation_id=raw[2],
result=None,
error=raw[4])
elif result_kind == 2:
return CompletionMessage(
headers=raw[1], invocation_id=raw[2],
result=None, error=None)
elif result_kind == 3:
return CompletionMessage(
headers=raw[1], invocation_id=raw[2],
result=raw[4], error=None)
else:
raise Exception("Unknown result kind.")
elif raw[0] == 4: # StreamInvocationMessage
return StreamInvocationMessage(
headers=raw[1], invocation_id=raw[2],
target=raw[3], arguments=raw[4]) # stream_id missing?
elif raw[0] == 5: # CancelInvocationMessage
return CancelInvocationMessage(
headers=raw[1], invocation_id=raw[2])
elif raw[0] == 6: # PingMessageEncoding
return PingMessage()
elif raw[0] == 7: # CloseMessageEncoding
return CloseMessage(error=raw[1]) # AllowReconnect is missing
print(".......................................")
print(raw)
print("---------------------------------------")
raise Exception("Unknown message type.")
def _to_varint(self, value):
buffer = b''
while True:
byte = value & 0x7f
value >>= 7
if value:
buffer += bytes((byte | 0x80, ))
else:
buffer += bytes((byte, ))
break
return buffer

@ -0,0 +1,68 @@
import uuid
import threading
from typing import Any
from .messages.invocation_message import InvocationClientStreamMessage
from .messages.stream_item_message import StreamItemMessage
from .messages.completion_message import CompletionClientStreamMessage
class Subject(object):
"""Client to server streaming
https://docs.microsoft.com/en-gb/aspnet/core/signalr/streaming?view=aspnetcore-5.0#client-to-server-streaming
items = list(range(0,10))
subject = Subject()
connection.send("UploadStream", subject)
while(len(self.items) > 0):
subject.next(str(self.items.pop()))
subject.complete()
"""
def __init__(self):
self.connection = None
self.target = None
self.invocation_id = str(uuid.uuid4())
self.lock = threading.RLock()
def check(self):
"""Ensures that invocation streaming object is correct
Raises:
ValueError: if object is not valid, exception will be raised
"""
if self.connection is None\
or self.target is None\
or self.invocation_id is None:
raise ValueError(
"subject must be passed as an agument to a send function. "
+ "hub_connection.send([method],[subject]")
def next(self, item: Any):
"""Send next item to the server
Args:
item (any): Item that will be streamed
"""
self.check()
with self.lock:
self.connection.transport.send(StreamItemMessage(
self.invocation_id,
item))
def start(self):
"""Starts streaming
"""
self.check()
with self.lock:
self.connection.transport.send(
InvocationClientStreamMessage(
[self.invocation_id],
self.target,
[]))
def complete(self):
"""Finish streaming
"""
self.check()
with self.lock:
self.connection.transport.send(CompletionClientStreamMessage(
self.invocation_id))

@ -0,0 +1,29 @@
from ..protocol.json_hub_protocol import JsonHubProtocol
from ..helpers import Helpers
class BaseTransport(object):
def __init__(self, protocol=JsonHubProtocol(), on_message=None):
self.protocol = protocol
self._on_message= on_message
self.logger = Helpers.get_logger()
self._on_open = lambda: self.logger.info("on_connect not defined")
self._on_close = lambda: self.logger.info(
"on_disconnect not defined")
def on_open_callback(self, callback):
self._on_open = callback
def on_close_callback(self, callback):
self._on_close = callback
def start(self): # pragma: no cover
raise NotImplementedError()
def stop(self): # pragma: no cover
raise NotImplementedError()
def is_running(self): # pragma: no cover
raise NotImplementedError()
def send(self, message, on_invocation = None): # pragma: no cover
raise NotImplementedError()

@ -0,0 +1,8 @@
from enum import Enum
class ConnectionState(Enum):
connecting = 0
connected = 1
reconnecting = 2
disconnected = 4

@ -0,0 +1,87 @@
import threading
import time
from enum import Enum
class ConnectionStateChecker(object):
def __init__(
self,
ping_function,
keep_alive_interval,
sleep=1):
self.sleep = sleep
self.keep_alive_interval = keep_alive_interval
self.last_message = time.time()
self.ping_function = ping_function
self.running = False
self._thread = None
def start(self):
self.running = True
self._thread = threading.Thread(target=self.run)
self._thread.daemon = True
self._thread.start()
def run(self):
while self.running:
time.sleep(self.sleep)
time_without_messages = time.time() - self.last_message
if self.keep_alive_interval < time_without_messages:
self.ping_function()
def stop(self):
self.running = False
class ReconnectionType(Enum):
raw = 0 # Reconnection with max reconnections and constant sleep time
interval = 1 # variable sleep time
class ReconnectionHandler(object):
def __init__(self):
self.reconnecting = False
self.attempt_number = 0
self.last_attempt = time.time()
def next(self):
raise NotImplementedError()
def reset(self):
self.attempt_number = 0
self.reconnecting = False
class RawReconnectionHandler(ReconnectionHandler):
def __init__(self, sleep_time, max_attempts):
super(RawReconnectionHandler, self).__init__()
self.sleep_time = sleep_time
self.max_reconnection_attempts = max_attempts
def next(self):
self.reconnecting = True
if self.max_reconnection_attempts is not None:
if self.attempt_number <= self.max_reconnection_attempts:
self.attempt_number += 1
return self.sleep_time
else:
raise ValueError(
"Max attemps reached {0}"
.format(self.max_reconnection_attempts))
else: # Infinite reconnect
return self.sleep_time
class IntervalReconnectionHandler(ReconnectionHandler):
def __init__(self, intervals):
super(IntervalReconnectionHandler, self).__init__()
self._intervals = intervals
def next(self):
self.reconnecting = True
index = self.attempt_number
self.attempt_number += 1
if index >= len(self._intervals):
raise ValueError(
"Max intervals reached {0}".format(self._intervals))
return self._intervals[index]

@ -0,0 +1,240 @@
import websocket
import threading
import requests
import traceback
import uuid
import time
import ssl
from .reconnection import ConnectionStateChecker
from .connection import ConnectionState
from ...messages.ping_message import PingMessage
from ...hub.errors import HubError, HubConnectionError, UnAuthorizedHubError
from ...protocol.messagepack_protocol import MessagePackHubProtocol
from ...protocol.json_hub_protocol import JsonHubProtocol
from ..base_transport import BaseTransport
from ...helpers import Helpers
class WebsocketTransport(BaseTransport):
def __init__(self,
url="",
headers={},
keep_alive_interval=15,
reconnection_handler=None,
verify_ssl=False,
skip_negotiation=False,
enable_trace=False,
**kwargs):
super(WebsocketTransport, self).__init__(**kwargs)
self._ws = None
self.enable_trace = enable_trace
self._thread = None
self.skip_negotiation = skip_negotiation
self.url = url
self.headers = headers
self.handshake_received = False
self.token = None # auth
self.state = ConnectionState.disconnected
self.connection_alive = False
self._thread = None
self._ws = None
self.verify_ssl = verify_ssl
self.connection_checker = ConnectionStateChecker(
lambda: self.send(PingMessage()),
keep_alive_interval
)
self.reconnection_handler = reconnection_handler
if len(self.logger.handlers) > 0:
websocket.enableTrace(self.enable_trace, self.logger.handlers[0])
def is_running(self):
return self.state != ConnectionState.disconnected
def stop(self):
if self.state == ConnectionState.connected:
self.connection_checker.stop()
self._ws.close()
self.state = ConnectionState.disconnected
self.handshake_received = False
def start(self):
if not self.skip_negotiation:
self.negotiate()
if self.state == ConnectionState.connected:
self.logger.warning("Already connected unable to start")
return False
self.state = ConnectionState.connecting
self.logger.debug("start url:" + self.url)
self._ws = websocket.WebSocketApp(
self.url,
header=self.headers,
on_message=self.on_message,
on_error=self.on_socket_error,
on_close=self.on_close,
on_open=self.on_open,
)
self._thread = threading.Thread(
target=lambda: self._ws.run_forever(
sslopt={"cert_reqs": ssl.CERT_NONE}
if not self.verify_ssl else {}
))
self._thread.daemon = True
self._thread.start()
return True
def negotiate(self):
negotiate_url = Helpers.get_negotiate_url(self.url)
self.logger.debug("Negotiate url:{0}".format(negotiate_url))
response = requests.post(
negotiate_url, headers=self.headers, verify=self.verify_ssl)
self.logger.debug(
"Response status code{0}".format(response.status_code))
if response.status_code != 200:
raise HubError(response.status_code)\
if response.status_code != 401 else UnAuthorizedHubError()
data = response.json()
if "connectionId" in data.keys():
self.url = Helpers.encode_connection_id(
self.url, data["connectionId"])
# Azure
if 'url' in data.keys() and 'accessToken' in data.keys():
Helpers.get_logger().debug(
"Azure url, reformat headers, token and url {0}".format(data))
self.url = data["url"]\
if data["url"].startswith("ws") else\
Helpers.http_to_websocket(data["url"])
self.token = data["accessToken"]
self.headers = {"Authorization": "Bearer " + self.token}
def evaluate_handshake(self, message):
self.logger.debug("Evaluating handshake {0}".format(message))
msg, messages = self.protocol.decode_handshake(message)
if msg.error is None or msg.error == "":
self.handshake_received = True
self.state = ConnectionState.connected
if self.reconnection_handler is not None:
self.reconnection_handler.reconnecting = False
if not self.connection_checker.running:
self.connection_checker.start()
else:
self.logger.error(msg.error)
self.on_socket_error(msg.error)
self.stop()
raise ValueError("Handshake error {0}".format(msg.error))
return messages
def on_open(self):
self.logger.debug("-- web socket open --")
msg = self.protocol.handshake_message()
self.send(msg)
def on_close(self):
self.logger.debug("-- web socket close --")
self.state = ConnectionState.disconnected
if self._on_close is not None and callable(self._on_close):
self._on_close()
def on_socket_error(self, error):
"""
Throws error related on
https://github.com/websocket-client/websocket-client/issues/449
Args:
error ([type]): [description]
Raises:
HubError: [description]
"""
self.logger.debug("-- web socket error --")
if (type(error) is AttributeError and
"'NoneType' object has no attribute 'connected'"
in str(error)):
url = "https://github.com/websocket-client" +\
"/websocket-client/issues/449"
self.logger.warning(
"Websocket closing error: issue" +
url)
self._on_close()
else:
self.logger.error(traceback.format_exc(5, True))
self.logger.error("{0} {1}".format(self, error))
self.logger.error("{0} {1}".format(error, type(error)))
self._on_close()
raise HubError(error)
def on_message(self, raw_message):
self.logger.debug("Message received{0}".format(raw_message))
self.connection_checker.last_message = time.time()
if not self.handshake_received:
messages = self.evaluate_handshake(raw_message)
if self._on_open is not None and callable(self._on_open):
self.state = ConnectionState.connected
self._on_open()
if len(messages) > 0:
return self._on_message(messages)
return []
return self._on_message(
self.protocol.parse_messages(raw_message))
def send(self, message):
self.logger.debug("Sending message {0}".format(message))
try:
self._ws.send(
self.protocol.encode(message),
opcode=0x2
if type(self.protocol) == MessagePackHubProtocol else
0x1)
self.connection_checker.last_message = time.time()
if self.reconnection_handler is not None:
self.reconnection_handler.reset()
except (
websocket._exceptions.WebSocketConnectionClosedException,
OSError) as ex:
self.handshake_received = False
self.logger.warning("Connection closed {0}".format(ex))
self.state = ConnectionState.disconnected
if self.reconnection_handler is None:
if self._on_close is not None and\
callable(self._on_close):
self._on_close()
raise ValueError(str(ex))
# Connection closed
self.handle_reconnect()
except Exception as ex:
raise ex
def handle_reconnect(self):
self.reconnection_handler.reconnecting = True
try:
self.stop()
self.start()
except Exception as ex:
self.logger.error(ex)
sleep_time = self.reconnection_handler.next()
threading.Thread(
target=self.deferred_reconnect,
args=(sleep_time,)
).start()
def deferred_reconnect(self, sleep_time):
time.sleep(sleep_time)
try:
if not self.connection_alive:
self.send(PingMessage())
except Exception as ex:
self.logger.error(ex)
self.reconnection_handler.reconnecting = False
self.connection_alive = False

@ -0,0 +1,196 @@
#!/usr/bin/env python
"""client library for iterating over http Server Sent Event (SSE) streams"""
#
# Distributed under the terms of the MIT license.
#
from __future__ import unicode_literals
import codecs
import re
import time
import warnings
import six
import requests
__version__ = '0.0.27'
# Technically, we should support streams that mix line endings. This regex,
# however, assumes that a system will provide consistent line endings.
end_of_field = re.compile(r'\r\n\r\n|\r\r|\n\n')
class SSEClient(object):
def __init__(self, url, last_id=None, retry=3000, session=None, chunk_size=1024, **kwargs):
self.url = url
self.last_id = last_id
self.retry = retry
self.chunk_size = chunk_size
# Optional support for passing in a requests.Session()
self.session = session
# Any extra kwargs will be fed into the requests.get call later.
self.requests_kwargs = kwargs
# The SSE spec requires making requests with Cache-Control: nocache
if 'headers' not in self.requests_kwargs:
self.requests_kwargs['headers'] = {}
self.requests_kwargs['headers']['Cache-Control'] = 'no-cache'
# The 'Accept' header is not required, but explicit > implicit
self.requests_kwargs['headers']['Accept'] = 'text/event-stream'
# Keep data here as it streams in
self.buf = ''
self._connect()
def _connect(self):
if self.last_id:
self.requests_kwargs['headers']['Last-Event-ID'] = self.last_id
# Use session if set. Otherwise fall back to requests module.
requester = self.session or requests
try:
self.resp = requester.get(self.url, stream=True, **self.requests_kwargs)
except requests.exceptions.ConnectionError:
raise requests.exceptions.ConnectionError
else:
self.resp_iterator = self.iter_content()
encoding = self.resp.encoding or self.resp.apparent_encoding
self.decoder = codecs.getincrementaldecoder(encoding)(errors='replace')
finally:
# TODO: Ensure we're handling redirects. Might also stick the 'origin'
# attribute on Events like the Javascript spec requires.
self.resp.raise_for_status()
def iter_content(self):
def generate():
while True:
if hasattr(self.resp.raw, '_fp') and \
hasattr(self.resp.raw._fp, 'fp') and \
hasattr(self.resp.raw._fp.fp, 'read1'):
chunk = self.resp.raw._fp.fp.read1(self.chunk_size)
else:
# _fp is not available, this means that we cannot use short
# reads and this will block until the full chunk size is
# actually read
chunk = self.resp.raw.read(self.chunk_size)
if not chunk:
break
yield chunk
return generate()
def _event_complete(self):
return re.search(end_of_field, self.buf) is not None
def __iter__(self):
return self
def __next__(self):
while not self._event_complete():
try:
next_chunk = next(self.resp_iterator)
if not next_chunk:
raise EOFError()
self.buf += self.decoder.decode(next_chunk)
except (StopIteration, requests.RequestException, EOFError, six.moves.http_client.IncompleteRead) as e:
# print(e)
time.sleep(self.retry / 1000.0)
self._connect()
# The SSE spec only supports resuming from a whole message, so
# if we have half a message we should throw it out.
head, sep, tail = self.buf.rpartition('\n')
self.buf = head + sep
continue
# Split the complete event (up to the end_of_field) into event_string,
# and retain anything after the current complete event in self.buf
# for next time.
(event_string, self.buf) = re.split(end_of_field, self.buf, maxsplit=1)
msg = Event.parse(event_string)
# If the server requests a specific retry delay, we need to honor it.
if msg.retry:
self.retry = msg.retry
# last_id should only be set if included in the message. It's not
# forgotten if a message omits it.
if msg.id:
self.last_id = msg.id
return msg
if six.PY2:
next = __next__
class Event(object):
sse_line_pattern = re.compile('(?P<name>[^:]*):?( ?(?P<value>.*))?')
def __init__(self, data='', event='message', id=None, retry=None):
assert isinstance(data, six.string_types), "Data must be text"
self.data = data
self.event = event
self.id = id
self.retry = retry
def dump(self):
lines = []
if self.id:
lines.append('id: %s' % self.id)
# Only include an event line if it's not the default already.
if self.event != 'message':
lines.append('event: %s' % self.event)
if self.retry:
lines.append('retry: %s' % self.retry)
lines.extend('data: %s' % d for d in self.data.split('\n'))
return '\n'.join(lines) + '\n\n'
@classmethod
def parse(cls, raw):
"""
Given a possibly-multiline string representing an SSE message, parse it
and return a Event object.
"""
msg = cls()
for line in raw.splitlines():
m = cls.sse_line_pattern.match(line)
if m is None:
# Malformed line. Discard but warn.
warnings.warn('Invalid SSE line: "%s"' % line, SyntaxWarning)
continue
name = m.group('name')
if name == '':
# line began with a ":", so is a comment. Ignore
continue
value = m.group('value')
if name == 'data':
# If we already have some data, then join to it with a newline.
# Else this is it.
if msg.data:
msg.data = '%s\n%s' % (msg.data, value)
else:
msg.data = value
elif name == 'event':
msg.event = value
elif name == 'id':
msg.id = value
elif name == 'retry':
msg.retry = int(value)
return msg
def __str__(self):
return self.data

@ -0,0 +1,38 @@
# Copyright 2018 Donald Stufft and individual contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = (
"__title__",
"__summary__",
"__uri__",
"__version__",
"__author__",
"__email__",
"__license__",
"__copyright__",
)
__copyright__ = "Copyright 2019 Donald Stufft and individual contributors"
import importlib_metadata
metadata = importlib_metadata.metadata("twine")
__title__ = metadata["name"]
__summary__ = metadata["summary"]
__uri__ = metadata["home-page"]
__version__ = metadata["version"]
__author__ = metadata["author"]
__email__ = metadata["author-email"]
__license__ = metadata["license"]

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright 2013 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import http
import sys
from typing import Any
import colorama
import requests
from twine import cli
from twine import exceptions
def main() -> Any:
try:
result = cli.dispatch(sys.argv[1:])
except requests.HTTPError as exc:
status_code = exc.response.status_code
status_phrase = http.HTTPStatus(status_code).phrase
result = (
f"{exc.__class__.__name__}: {status_code} {status_phrase} "
f"from {exc.response.url}\n"
f"{exc.response.reason}"
)
except exceptions.TwineException as exc:
result = f"{exc.__class__.__name__}: {exc.args[0]}"
return _format_error(result) if isinstance(result, str) else result
def _format_error(message: str) -> str:
pre_style, post_style = "", ""
if not cli.args.no_color:
colorama.init()
pre_style, post_style = colorama.Fore.RED, colorama.Style.RESET_ALL
return f"{pre_style}{message}{post_style}"
if __name__ == "__main__":
sys.exit(main())

@ -0,0 +1,100 @@
import functools
import getpass
import logging
import warnings
from typing import Callable, Optional, Type, cast
import keyring
from twine import exceptions
from twine import utils
logger = logging.getLogger(__name__)
class CredentialInput:
def __init__(
self, username: Optional[str] = None, password: Optional[str] = None
) -> None:
self.username = username
self.password = password
class Resolver:
def __init__(self, config: utils.RepositoryConfig, input: CredentialInput) -> None:
self.config = config
self.input = input
@classmethod
def choose(cls, interactive: bool) -> Type["Resolver"]:
return cls if interactive else Private
@property # type: ignore # https://github.com/python/mypy/issues/1362
@functools.lru_cache()
def username(self) -> Optional[str]:
return utils.get_userpass_value(
self.input.username,
self.config,
key="username",
prompt_strategy=self.username_from_keyring_or_prompt,
)
@property # type: ignore # https://github.com/python/mypy/issues/1362
@functools.lru_cache()
def password(self) -> Optional[str]:
return utils.get_userpass_value(
self.input.password,
self.config,
key="password",
prompt_strategy=self.password_from_keyring_or_prompt,
)
@property
def system(self) -> Optional[str]:
return self.config["repository"]
def get_username_from_keyring(self) -> Optional[str]:
try:
system = cast(str, self.system)
creds = keyring.get_credential(system, None)
if creds:
return cast(str, creds.username)
except AttributeError:
# To support keyring prior to 15.2
pass
except Exception as exc:
warnings.warn(str(exc))
return None
def get_password_from_keyring(self) -> Optional[str]:
try:
system = cast(str, self.system)
username = cast(str, self.username)
return cast(str, keyring.get_password(system, username))
except Exception as exc:
warnings.warn(str(exc))
return None
def username_from_keyring_or_prompt(self) -> str:
username = self.get_username_from_keyring()
if username:
logger.info("username set from keyring")
return username
return self.prompt("username", input)
def password_from_keyring_or_prompt(self) -> str:
password = self.get_password_from_keyring()
if password:
logger.info("password set from keyring")
return password
return self.prompt("password", getpass.getpass)
def prompt(self, what: str, how: Callable[..., str]) -> str:
return how(f"Enter your {what}: ")
class Private(Resolver):
def prompt(self, what: str, how: Optional[Callable[..., str]] = None) -> str:
raise exceptions.NonInteractive(f"Credential not found for {what}.")

@ -0,0 +1,71 @@
# Copyright 2013 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Any, List, Tuple
from importlib_metadata import entry_points
from importlib_metadata import version
import twine
args = argparse.Namespace()
def list_dependencies_and_versions() -> List[Tuple[str, str]]:
deps = (
"importlib_metadata",
"pkginfo",
"requests",
"requests-toolbelt",
"tqdm",
)
return [(dep, version(dep)) for dep in deps] # type: ignore[no-untyped-call] # python/importlib_metadata#288 # noqa: E501
def dep_versions() -> str:
return ", ".join(
"{}: {}".format(*dependency) for dependency in list_dependencies_and_versions()
)
def dispatch(argv: List[str]) -> Any:
registered_commands = entry_points(group="twine.registered_commands")
parser = argparse.ArgumentParser(prog="twine")
parser.add_argument(
"--version",
action="version",
version="%(prog)s version {} ({})".format(twine.__version__, dep_versions()),
)
parser.add_argument(
"--no-color",
default=False,
required=False,
action="store_true",
help="disable colored output",
)
parser.add_argument(
"command",
choices=registered_commands.names,
)
parser.add_argument(
"args",
help=argparse.SUPPRESS,
nargs=argparse.REMAINDER,
)
parser.parse_args(argv, namespace=args)
main = registered_commands[args.command].load()
return main(args.args)

@ -0,0 +1,48 @@
# Copyright 2013 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os.path
from typing import List
from twine import exceptions
__all__: List[str] = []
def _group_wheel_files_first(files: List[str]) -> List[str]:
if not any(fname for fname in files if fname.endswith(".whl")):
# Return early if there's no wheel files
return files
files.sort(key=lambda x: -1 if x.endswith(".whl") else 0)
return files
def _find_dists(dists: List[str]) -> List[str]:
uploads = []
for filename in dists:
if os.path.exists(filename):
uploads.append(filename)
continue
# The filename didn't exist so it may be a glob
files = glob.glob(filename)
# If nothing matches, files is []
if not files:
raise exceptions.InvalidDistribution(
"Cannot find file (or expand pattern): '%s'" % filename
)
# Otherwise, files will be filenames that exist
uploads.extend(files)
return _group_wheel_files_first(uploads)

@ -0,0 +1,167 @@
# Copyright 2018 Dustin Ingram
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import cgi
import io
import re
import sys
import textwrap
from typing import IO, List, Optional, Tuple, cast
import readme_renderer.rst
from twine import commands
from twine import package as package_file
_RENDERERS = {
None: readme_renderer.rst, # Default if description_content_type is None
"text/plain": None, # Rendering cannot fail
"text/x-rst": readme_renderer.rst,
"text/markdown": None, # Rendering cannot fail
}
# Regular expression used to capture and reformat docutils warnings into
# something that a human can understand. This is loosely borrowed from
# Sphinx: https://github.com/sphinx-doc/sphinx/blob
# /c35eb6fade7a3b4a6de4183d1dd4196f04a5edaf/sphinx/util/docutils.py#L199
_REPORT_RE = re.compile(
r"^<string>:(?P<line>(?:\d+)?): "
r"\((?P<level>DEBUG|INFO|WARNING|ERROR|SEVERE)/(\d+)?\) "
r"(?P<message>.*)",
re.DOTALL | re.MULTILINE,
)
class _WarningStream:
def __init__(self) -> None:
self.output = io.StringIO()
def write(self, text: str) -> None:
matched = _REPORT_RE.search(text)
if not matched:
self.output.write(text)
return
self.output.write(
"line {line}: {level_text}: {message}\n".format(
level_text=matched.group("level").capitalize(),
line=matched.group("line"),
message=matched.group("message").rstrip("\r\n"),
)
)
def __str__(self) -> str:
return self.output.getvalue()
def _check_file(
filename: str, render_warning_stream: _WarningStream
) -> Tuple[List[str], bool]:
"""Check given distribution."""
warnings = []
is_ok = True
package = package_file.PackageFile.from_filename(filename, comment=None)
metadata = package.metadata_dictionary()
description = cast(Optional[str], metadata["description"])
description_content_type = cast(Optional[str], metadata["description_content_type"])
if description_content_type is None:
warnings.append(
"`long_description_content_type` missing. defaulting to `text/x-rst`."
)
description_content_type = "text/x-rst"
content_type, params = cgi.parse_header(description_content_type)
renderer = _RENDERERS.get(content_type, _RENDERERS[None])
if description in {None, "UNKNOWN\n\n\n"}:
warnings.append("`long_description` missing.")
elif renderer:
rendering_result = renderer.render(
description, stream=render_warning_stream, **params
)
if rendering_result is None:
is_ok = False
return warnings, is_ok
def check(
dists: List[str],
output_stream: IO[str] = sys.stdout,
strict: bool = False,
) -> bool:
uploads = [i for i in commands._find_dists(dists) if not i.endswith(".asc")]
if not uploads: # Return early, if there are no files to check.
output_stream.write("No files to check.\n")
return False
failure = False
for filename in uploads:
output_stream.write("Checking %s: " % filename)
render_warning_stream = _WarningStream()
warnings, is_ok = _check_file(filename, render_warning_stream)
# Print the status and/or error
if not is_ok:
failure = True
output_stream.write("FAILED\n")
error_text = (
"`long_description` has syntax errors in markup and "
"would not be rendered on PyPI.\n"
)
output_stream.write(textwrap.indent(error_text, " "))
output_stream.write(textwrap.indent(str(render_warning_stream), " "))
elif warnings:
if strict:
failure = True
output_stream.write("FAILED, due to warnings\n")
else:
output_stream.write("PASSED, with warnings\n")
else:
output_stream.write("PASSED\n")
# Print warnings after the status and/or error
for message in warnings:
output_stream.write(" warning: " + message + "\n")
return failure
def main(args: List[str]) -> bool:
parser = argparse.ArgumentParser(prog="twine check")
parser.add_argument(
"dists",
nargs="+",
metavar="dist",
help="The distribution files to check, usually dist/*",
)
parser.add_argument(
"--strict",
action="store_true",
default=False,
required=False,
help="Fail on warnings",
)
parsed_args = parser.parse_args(args)
# Call the check function with the arguments from the command line
return check(parsed_args.dists, strict=parsed_args.strict)

@ -0,0 +1,63 @@
# Copyright 2015 Ian Cordasco
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os.path
from typing import List, cast
from twine import exceptions
from twine import package as package_file
from twine import settings
def register(register_settings: settings.Settings, package: str) -> None:
repository_url = cast(str, register_settings.repository_config["repository"])
print(f"Registering package to {repository_url}")
repository = register_settings.create_repository()
if not os.path.exists(package):
raise exceptions.PackageNotFound(
f'"{package}" does not exist on the file system.'
)
resp = repository.register(
package_file.PackageFile.from_filename(package, register_settings.comment)
)
repository.close()
if resp.is_redirect:
raise exceptions.RedirectDetected.from_args(
repository_url,
resp.headers["location"],
)
resp.raise_for_status()
def main(args: List[str]) -> None:
parser = argparse.ArgumentParser(
prog="twine register",
description="register operation is not required with PyPI.org",
)
settings.Settings.register_argparse_arguments(parser)
parser.add_argument(
"package",
metavar="package",
help="File from which we read the package metadata.",
)
parsed_args = parser.parse_args(args)
register_settings = settings.Settings.from_argparse(parsed_args)
# Call the register function with the args from the command line
register(register_settings, parsed_args.package)

@ -0,0 +1,154 @@
# Copyright 2013 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os.path
from typing import Dict, List, cast
import requests
from twine import commands
from twine import exceptions
from twine import package as package_file
from twine import settings
from twine import utils
logger = logging.getLogger(__name__)
def skip_upload(
response: requests.Response, skip_existing: bool, package: package_file.PackageFile
) -> bool:
if not skip_existing:
return False
status = response.status_code
reason = getattr(response, "reason", "").lower()
text = getattr(response, "text", "").lower()
# NOTE(sigmavirus24): PyPI presently returns a 400 status code with the
# error message in the reason attribute. Other implementations return a
# 403 or 409 status code.
return (
# pypiserver (https://pypi.org/project/pypiserver)
status == 409
# PyPI / TestPyPI
or (status == 400 and "already exist" in reason)
# Nexus Repository OSS (https://www.sonatype.com/nexus-repository-oss)
or (status == 400 and any("updating asset" in x for x in [reason, text]))
# Artifactory (https://jfrog.com/artifactory/)
or (status == 403 and "overwrite artifact" in text)
# Gitlab Enterprise Edition (https://about.gitlab.com)
or (status == 400 and "already been taken" in text)
)
def _make_package(
filename: str, signatures: Dict[str, str], upload_settings: settings.Settings
) -> package_file.PackageFile:
"""Create and sign a package, based off of filename, signatures and settings."""
package = package_file.PackageFile.from_filename(filename, upload_settings.comment)
signed_name = package.signed_basefilename
if signed_name in signatures:
package.add_gpg_signature(signatures[signed_name], signed_name)
elif upload_settings.sign:
package.sign(upload_settings.sign_with, upload_settings.identity)
file_size = utils.get_file_size(package.filename)
logger.info(f" {package.filename} ({file_size})")
if package.gpg_signature:
logger.info(f" Signed with {package.signed_filename}")
return package
def upload(upload_settings: settings.Settings, dists: List[str]) -> None:
dists = commands._find_dists(dists)
# Determine if the user has passed in pre-signed distributions
signatures = {os.path.basename(d): d for d in dists if d.endswith(".asc")}
uploads = [i for i in dists if not i.endswith(".asc")]
upload_settings.check_repository_url()
repository_url = cast(str, upload_settings.repository_config["repository"])
print(f"Uploading distributions to {repository_url}")
packages_to_upload = [
_make_package(filename, signatures, upload_settings) for filename in uploads
]
repository = upload_settings.create_repository()
uploaded_packages = []
for package in packages_to_upload:
skip_message = " Skipping {} because it appears to already exist".format(
package.basefilename
)
# Note: The skip_existing check *needs* to be first, because otherwise
# we're going to generate extra HTTP requests against a hardcoded
# URL for no reason.
if upload_settings.skip_existing and repository.package_is_uploaded(package):
print(skip_message)
continue
resp = repository.upload(package)
# Bug 92. If we get a redirect we should abort because something seems
# funky. The behaviour is not well defined and redirects being issued
# by PyPI should never happen in reality. This should catch malicious
# redirects as well.
if resp.is_redirect:
raise exceptions.RedirectDetected.from_args(
repository_url,
resp.headers["location"],
)
if skip_upload(resp, upload_settings.skip_existing, package):
print(skip_message)
continue
utils.check_status_code(resp, upload_settings.verbose)
uploaded_packages.append(package)
release_urls = repository.release_urls(uploaded_packages)
if release_urls:
print("\nView at:")
for url in release_urls:
print(url)
# Bug 28. Try to silence a ResourceWarning by clearing the connection
# pool.
repository.close()
def main(args: List[str]) -> None:
parser = argparse.ArgumentParser(prog="twine upload")
settings.Settings.register_argparse_arguments(parser)
parser.add_argument(
"dists",
nargs="+",
metavar="dist",
help="The distribution files to upload to the repository "
"(package index). Usually dist/* . May additionally contain "
"a .asc file to include an existing signature with the "
"file upload.",
)
parsed_args = parser.parse_args(args)
upload_settings = settings.Settings.from_argparse(parsed_args)
# Call the upload function with the arguments from the command line
return upload(upload_settings, parsed_args.dists)

@ -0,0 +1,123 @@
"""Module containing exceptions raised by twine."""
# Copyright 2015 Ian Stapleton Cordasco
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TwineException(Exception):
"""Base class for all exceptions raised by twine."""
pass
class RedirectDetected(TwineException):
"""A redirect was detected that the user needs to resolve.
In some cases, requests refuses to issue a new POST request after a
redirect. In order to prevent a confusing user experience, we raise this
exception to allow users to know the index they're uploading to is
redirecting them.
"""
@classmethod
def from_args(cls, repository_url: str, redirect_url: str) -> "RedirectDetected":
msg = "\n".join(
[
"{} attempted to redirect to {}.".format(repository_url, redirect_url),
"If you trust these URLs, set {} as your repository URL.".format(
redirect_url
),
"Aborting.",
]
)
return cls(msg)
class PackageNotFound(TwineException):
"""A package file was provided that could not be found on the file system.
This is only used when attempting to register a package_file.
"""
pass
class UploadToDeprecatedPyPIDetected(TwineException):
"""An upload attempt was detected to deprecated PyPI domains.
The sites pypi.python.org and testpypi.python.org are deprecated.
"""
@classmethod
def from_args(
cls, target_url: str, default_url: str, test_url: str
) -> "UploadToDeprecatedPyPIDetected":
"""Return an UploadToDeprecatedPyPIDetected instance."""
return cls(
"You're trying to upload to the legacy PyPI site '{}'. "
"Uploading to those sites is deprecated. \n "
"The new sites are pypi.org and test.pypi.org. Try using "
"{} (or {}) to upload your packages instead. "
"These are the default URLs for Twine now. \n More at "
"https://packaging.python.org/guides/migrating-to-pypi-org/"
" .".format(target_url, default_url, test_url)
)
class UnreachableRepositoryURLDetected(TwineException):
"""An upload attempt was detected to a URL without a protocol prefix.
All repository URLs must have a protocol (e.g., ``https://``).
"""
pass
class InvalidSigningConfiguration(TwineException):
"""Both the sign and identity parameters must be present."""
pass
class InvalidSigningExecutable(TwineException):
"""Signing executable must be installed on system."""
pass
class InvalidConfiguration(TwineException):
"""Raised when configuration is invalid."""
pass
class InvalidDistribution(TwineException):
"""Raised when a distribution is invalid."""
pass
class NonInteractive(TwineException):
"""Raised in non-interactive mode when credentials could not be found."""
pass
class InvalidPyPIUploadURL(TwineException):
"""Repository configuration tries to use PyPI with an incorrect URL.
For example, https://pypi.org instead of https://upload.pypi.org/legacy.
"""
pass

@ -0,0 +1,291 @@
# Copyright 2015 Ian Cordasco
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import io
import os
import re
import subprocess
from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union
import importlib_metadata
import pkginfo
from twine import exceptions
from twine import wheel
from twine import wininst
DIST_TYPES = {
"bdist_wheel": wheel.Wheel,
"bdist_wininst": wininst.WinInst,
"bdist_egg": pkginfo.BDist,
"sdist": pkginfo.SDist,
}
DIST_EXTENSIONS = {
".whl": "bdist_wheel",
".exe": "bdist_wininst",
".egg": "bdist_egg",
".tar.bz2": "sdist",
".tar.gz": "sdist",
".zip": "sdist",
}
MetadataValue = Union[str, Sequence[str]]
def _safe_name(name: str) -> str:
"""Convert an arbitrary string to a standard distribution name.
Any runs of non-alphanumeric/. characters are replaced with a single '-'.
Copied from pkg_resources.safe_name for compatibility with warehouse.
See https://github.com/pypa/twine/issues/743.
"""
return re.sub("[^A-Za-z0-9.]+", "-", name)
class PackageFile:
def __init__(
self,
filename: str,
comment: Optional[str],
metadata: pkginfo.Distribution,
python_version: Optional[str],
filetype: Optional[str],
) -> None:
self.filename = filename
self.basefilename = os.path.basename(filename)
self.comment = comment
self.metadata = metadata
self.python_version = python_version
self.filetype = filetype
self.safe_name = _safe_name(metadata.name)
self.signed_filename = self.filename + ".asc"
self.signed_basefilename = self.basefilename + ".asc"
self.gpg_signature: Optional[Tuple[str, bytes]] = None
hasher = HashManager(filename)
hasher.hash()
hexdigest = hasher.hexdigest()
self.md5_digest = hexdigest.md5
self.sha2_digest = hexdigest.sha2
self.blake2_256_digest = hexdigest.blake2
@classmethod
def from_filename(cls, filename: str, comment: Optional[str]) -> "PackageFile":
# Extract the metadata from the package
for ext, dtype in DIST_EXTENSIONS.items():
if filename.endswith(ext):
try:
meta = DIST_TYPES[dtype](filename)
except EOFError:
raise exceptions.InvalidDistribution(
"Invalid distribution file: '%s'" % os.path.basename(filename)
)
else:
break
else:
raise exceptions.InvalidDistribution(
"Unknown distribution format: '%s'" % os.path.basename(filename)
)
# If pkginfo encounters a metadata version it doesn't support, it may
# give us back empty metadata. At the very least, we should have a name
# and version
if not (meta.name and meta.version):
raise exceptions.InvalidDistribution(
"Invalid distribution metadata. Try upgrading twine if possible."
)
py_version: Optional[str]
if dtype == "bdist_egg":
(dist,) = importlib_metadata.Distribution.discover( # type: ignore[no-untyped-call] # python/importlib_metadata#288 # noqa: E501
path=[filename]
)
py_version = dist.metadata["Version"]
elif dtype == "bdist_wheel":
py_version = meta.py_version
elif dtype == "bdist_wininst":
py_version = meta.py_version
else:
py_version = None
return cls(filename, comment, meta, py_version, dtype)
def metadata_dictionary(self) -> Dict[str, MetadataValue]:
meta = self.metadata
data = {
# identify release
"name": self.safe_name,
"version": meta.version,
# file content
"filetype": self.filetype,
"pyversion": self.python_version,
# additional meta-data
"metadata_version": meta.metadata_version,
"summary": meta.summary,
"home_page": meta.home_page,
"author": meta.author,
"author_email": meta.author_email,
"maintainer": meta.maintainer,
"maintainer_email": meta.maintainer_email,
"license": meta.license,
"description": meta.description,
"keywords": meta.keywords,
"platform": meta.platforms,
"classifiers": meta.classifiers,
"download_url": meta.download_url,
"supported_platform": meta.supported_platforms,
"comment": self.comment,
"md5_digest": self.md5_digest,
"sha256_digest": self.sha2_digest,
"blake2_256_digest": self.blake2_256_digest,
# PEP 314
"provides": meta.provides,
"requires": meta.requires,
"obsoletes": meta.obsoletes,
# Metadata 1.2
"project_urls": meta.project_urls,
"provides_dist": meta.provides_dist,
"obsoletes_dist": meta.obsoletes_dist,
"requires_dist": meta.requires_dist,
"requires_external": meta.requires_external,
"requires_python": meta.requires_python,
# Metadata 2.1
"provides_extras": meta.provides_extras,
"description_content_type": meta.description_content_type,
}
if self.gpg_signature is not None:
data["gpg_signature"] = self.gpg_signature
return data
def add_gpg_signature(
self, signature_filepath: str, signature_filename: str
) -> None:
if self.gpg_signature is not None:
raise exceptions.InvalidDistribution("GPG Signature can only be added once")
with open(signature_filepath, "rb") as gpg:
self.gpg_signature = (signature_filename, gpg.read())
def sign(self, sign_with: str, identity: Optional[str]) -> None:
print(f"Signing {self.basefilename}")
gpg_args: Tuple[str, ...] = (sign_with, "--detach-sign")
if identity:
gpg_args += ("--local-user", identity)
gpg_args += ("-a", self.filename)
self.run_gpg(gpg_args)
self.add_gpg_signature(self.signed_filename, self.signed_basefilename)
@classmethod
def run_gpg(cls, gpg_args: Tuple[str, ...]) -> None:
try:
subprocess.check_call(gpg_args)
return
except FileNotFoundError:
if gpg_args[0] != "gpg":
raise exceptions.InvalidSigningExecutable(
"{} executable not available.".format(gpg_args[0])
)
print("gpg executable not available. Attempting fallback to gpg2.")
try:
subprocess.check_call(("gpg2",) + gpg_args[1:])
except FileNotFoundError:
print("gpg2 executable not available.")
raise exceptions.InvalidSigningExecutable(
"'gpg' or 'gpg2' executables not available. "
"Try installing one of these or specifying an executable "
"with the --sign-with flag."
)
class Hexdigest(NamedTuple):
md5: Optional[str]
sha2: Optional[str]
blake2: Optional[str]
class HashManager:
"""Manage our hashing objects for simplicity.
This will also allow us to better test this logic.
"""
def __init__(self, filename: str) -> None:
"""Initialize our manager and hasher objects."""
self.filename = filename
self._md5_hasher = None
try:
self._md5_hasher = hashlib.md5()
except ValueError:
# FIPs mode disables MD5
pass
self._sha2_hasher = hashlib.sha256()
self._blake_hasher = None
try:
self._blake_hasher = hashlib.blake2b(digest_size=256 // 8)
except ValueError:
# FIPS mode disables blake2
pass
def _md5_update(self, content: bytes) -> None:
if self._md5_hasher is not None:
self._md5_hasher.update(content)
def _md5_hexdigest(self) -> Optional[str]:
if self._md5_hasher is not None:
return self._md5_hasher.hexdigest()
return None
def _sha2_update(self, content: bytes) -> None:
if self._sha2_hasher is not None:
self._sha2_hasher.update(content)
def _sha2_hexdigest(self) -> Optional[str]:
if self._sha2_hasher is not None:
return self._sha2_hasher.hexdigest()
return None
def _blake_update(self, content: bytes) -> None:
if self._blake_hasher is not None:
self._blake_hasher.update(content)
def _blake_hexdigest(self) -> Optional[str]:
if self._blake_hasher is not None:
return self._blake_hasher.hexdigest()
return None
def hash(self) -> None:
"""Hash the file contents."""
with open(self.filename, "rb") as fp:
for content in iter(lambda: fp.read(io.DEFAULT_BUFFER_SIZE), b""):
self._md5_update(content)
self._sha2_update(content)
self._blake_update(content)
def hexdigest(self) -> Hexdigest:
"""Return the hexdigest for the file."""
return Hexdigest(
self._md5_hexdigest(),
self._sha2_hexdigest(),
self._blake_hexdigest(),
)

@ -0,0 +1,264 @@
# Copyright 2015 Ian Cordasco
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
from typing import Any, Dict, List, Optional, Set, Tuple, cast
import requests
import requests_toolbelt
import tqdm
import urllib3
from requests import adapters
from requests_toolbelt.utils import user_agent
import twine
from twine import package as package_file
KEYWORDS_TO_NOT_FLATTEN = {"gpg_signature", "content"}
LEGACY_PYPI = "https://pypi.python.org/"
LEGACY_TEST_PYPI = "https://testpypi.python.org/"
WAREHOUSE = "https://upload.pypi.org/"
OLD_WAREHOUSE = "https://upload.pypi.io/"
TEST_WAREHOUSE = "https://test.pypi.org/"
WAREHOUSE_WEB = "https://pypi.org/"
logger = logging.getLogger(__name__)
class ProgressBar(tqdm.tqdm):
def update_to(self, n: int) -> None:
"""Update the bar in the way compatible with requests-toolbelt.
This is identical to tqdm.update, except ``n`` will be the current
value - not the delta as tqdm expects.
"""
self.update(n - self.n) # will also do self.n = n
class Repository:
def __init__(
self,
repository_url: str,
username: Optional[str],
password: Optional[str],
disable_progress_bar: bool = False,
) -> None:
self.url = repository_url
self.session = requests.session()
# requests.Session.auth should be Union[None, Tuple[str, str], ...]
# But username or password could be None
# See TODO for utils.RepositoryConfig
self.session.auth = (
(username or "", password or "") if username or password else None
)
logger.info(f"username: {username if username else '<empty>'}")
logger.info(f"password: <{'hidden' if password else 'empty'}>")
self.session.headers["User-Agent"] = self._make_user_agent_string()
for scheme in ("http://", "https://"):
self.session.mount(scheme, self._make_adapter_with_retries())
# Working around https://github.com/python/typing/issues/182
self._releases_json_data: Dict[str, Dict[str, Any]] = {}
self.disable_progress_bar = disable_progress_bar
@staticmethod
def _make_adapter_with_retries() -> adapters.HTTPAdapter:
retry_kwargs = dict(
connect=5,
total=10,
status_forcelist=[500, 501, 502, 503],
)
try:
retry = urllib3.Retry(allowed_methods=["GET"], **retry_kwargs)
except TypeError: # pragma: no cover
# Avoiding DeprecationWarning starting in urllib3 1.26
# Remove when that's the mininum version
retry = urllib3.Retry(method_whitelist=["GET"], **retry_kwargs)
return adapters.HTTPAdapter(max_retries=retry)
@staticmethod
def _make_user_agent_string() -> str:
from twine import cli
dependencies = cli.list_dependencies_and_versions()
user_agent_string = (
user_agent.UserAgentBuilder("twine", twine.__version__)
.include_extras(dependencies)
.include_implementation()
.build()
)
return cast(str, user_agent_string)
def close(self) -> None:
self.session.close()
@staticmethod
def _convert_data_to_list_of_tuples(data: Dict[str, Any]) -> List[Tuple[str, Any]]:
data_to_send = []
for key, value in data.items():
if key in KEYWORDS_TO_NOT_FLATTEN or not isinstance(value, (list, tuple)):
data_to_send.append((key, value))
else:
for item in value:
data_to_send.append((key, item))
return data_to_send
def set_certificate_authority(self, cacert: Optional[str]) -> None:
if cacert:
self.session.verify = cacert
def set_client_certificate(self, clientcert: Optional[str]) -> None:
if clientcert:
self.session.cert = clientcert
def register(self, package: package_file.PackageFile) -> requests.Response:
data = package.metadata_dictionary()
data.update({":action": "submit", "protocol_version": "1"})
print(f"Registering {package.basefilename}")
data_to_send = self._convert_data_to_list_of_tuples(data)
encoder = requests_toolbelt.MultipartEncoder(data_to_send)
resp = self.session.post(
self.url,
data=encoder,
allow_redirects=False,
headers={"Content-Type": encoder.content_type},
)
# Bug 28. Try to silence a ResourceWarning by releasing the socket.
resp.close()
return resp
def _upload(self, package: package_file.PackageFile) -> requests.Response:
data = package.metadata_dictionary()
data.update(
{
# action
":action": "file_upload",
"protocol_version": "1",
}
)
data_to_send = self._convert_data_to_list_of_tuples(data)
print(f"Uploading {package.basefilename}")
with open(package.filename, "rb") as fp:
data_to_send.append(
("content", (package.basefilename, fp, "application/octet-stream"))
)
encoder = requests_toolbelt.MultipartEncoder(data_to_send)
with ProgressBar(
total=encoder.len,
unit="B",
unit_scale=True,
unit_divisor=1024,
miniters=1,
file=sys.stdout,
disable=self.disable_progress_bar,
) as bar:
monitor = requests_toolbelt.MultipartEncoderMonitor(
encoder, lambda monitor: bar.update_to(monitor.bytes_read)
)
resp = self.session.post(
self.url,
data=monitor,
allow_redirects=False,
headers={"Content-Type": monitor.content_type},
)
return resp
def upload(
self, package: package_file.PackageFile, max_redirects: int = 5
) -> requests.Response:
number_of_redirects = 0
while number_of_redirects < max_redirects:
resp = self._upload(package)
if resp.status_code == requests.codes.OK:
return resp
if 500 <= resp.status_code < 600:
number_of_redirects += 1
print(
'Received "{status_code}: {reason}" Package upload '
"appears to have failed. Retry {retry} of "
"{max_redirects}".format(
status_code=resp.status_code,
reason=resp.reason,
retry=number_of_redirects,
max_redirects=max_redirects,
)
)
else:
return resp
return resp
def package_is_uploaded(
self, package: package_file.PackageFile, bypass_cache: bool = False
) -> bool:
# NOTE(sigmavirus24): Not all indices are PyPI and pypi.io doesn't
# have a similar interface for finding the package versions.
if not self.url.startswith((LEGACY_PYPI, WAREHOUSE, OLD_WAREHOUSE)):
return False
safe_name = package.safe_name
releases = None
if not bypass_cache:
releases = self._releases_json_data.get(safe_name)
if releases is None:
url = "{url}pypi/{package}/json".format(package=safe_name, url=LEGACY_PYPI)
headers = {"Accept": "application/json"}
response = self.session.get(url, headers=headers)
if response.status_code == 200:
releases = response.json()["releases"]
else:
releases = {}
self._releases_json_data[safe_name] = releases
packages = releases.get(package.metadata.version, [])
for uploaded_package in packages:
if uploaded_package["filename"] == package.basefilename:
return True
return False
def release_urls(self, packages: List[package_file.PackageFile]) -> Set[str]:
if self.url.startswith(WAREHOUSE):
url = WAREHOUSE_WEB
elif self.url.startswith(TEST_WAREHOUSE):
url = TEST_WAREHOUSE
else:
return set()
return {
"{}project/{}/{}/".format(url, package.safe_name, package.metadata.version)
for package in packages
}
def verify_package_integrity(self, package: package_file.PackageFile) -> None:
# TODO(sigmavirus24): Add a way for users to download the package and
# check it's hash against what it has locally.
pass

@ -0,0 +1,351 @@
"""Module containing logic for handling settings."""
# Copyright 2018 Ian Stapleton Cordasco
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import logging
import sys
from typing import Any, ContextManager, Optional, cast
from twine import auth
from twine import exceptions
from twine import repository
from twine import utils
class Settings:
"""Object that manages the configuration for Twine.
This object can only be instantiated with keyword arguments.
For example,
.. code-block:: python
Settings(True, username='fakeusername')
Will raise a :class:`TypeError`. Instead, you would want
.. code-block:: python
Settings(sign=True, username='fakeusername')
"""
def __init__(
self,
*,
sign: bool = False,
sign_with: str = "gpg",
identity: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
non_interactive: bool = False,
comment: Optional[str] = None,
config_file: str = "~/.pypirc",
skip_existing: bool = False,
cacert: Optional[str] = None,
client_cert: Optional[str] = None,
repository_name: str = "pypi",
repository_url: Optional[str] = None,
verbose: bool = False,
disable_progress_bar: bool = False,
**ignored_kwargs: Any,
) -> None:
"""Initialize our settings instance.
:param bool sign:
Configure whether the package file should be signed.
This defaults to ``False``.
:param str sign_with:
The name of the executable used to sign the package with.
This defaults to ``gpg``.
:param str identity:
The GPG identity that should be used to sign the package file.
:param str username:
The username used to authenticate to the repository (package
index).
:param str password:
The password used to authenticate to the repository (package
index).
:param bool non_interactive:
Do not interactively prompt for username/password if the required
credentials are missing.
This defaults to ``False``.
:param str comment:
The comment to include with each distribution file.
:param str config_file:
The path to the configuration file to use.
This defaults to ``~/.pypirc``.
:param bool skip_existing:
Specify whether twine should continue uploading files if one
of them already exists. This primarily supports PyPI. Other
package indexes may not be supported.
This defaults to ``False``.
:param str cacert:
The path to the bundle of certificates used to verify the TLS
connection to the package index.
:param str client_cert:
The path to the client certificate used to perform authentication
to the index.
This must be a single file that contains both the private key and
the PEM-encoded certificate.
:param str repository_name:
The name of the repository (package index) to interact with. This
should correspond to a section in the config file.
:param str repository_url:
The URL of the repository (package index) to interact with. This
will override the settings inferred from ``repository_name``.
:param bool verbose:
Show verbose output.
:param bool disable_progress_bar:
Disable the progress bar.
This defaults to ``False``
"""
self.config_file = config_file
self.comment = comment
self.verbose = verbose
self.disable_progress_bar = disable_progress_bar
self.skip_existing = skip_existing
self._handle_repository_options(
repository_name=repository_name,
repository_url=repository_url,
)
self._handle_package_signing(
sign=sign,
sign_with=sign_with,
identity=identity,
)
# _handle_certificates relies on the parsed repository config
self._handle_certificates(cacert, client_cert)
self.auth = auth.Resolver.choose(not non_interactive)(
self.repository_config,
auth.CredentialInput(username, password),
)
@property
def username(self) -> Optional[str]:
# Workaround for https://github.com/python/mypy/issues/5858
return cast(Optional[str], self.auth.username)
@property
def password(self) -> Optional[str]:
with self._allow_noninteractive():
# Workaround for https://github.com/python/mypy/issues/5858
return cast(Optional[str], self.auth.password)
def _allow_noninteractive(self) -> ContextManager[None]:
"""Bypass NonInteractive error when client cert is present."""
suppressed = (exceptions.NonInteractive,) if self.client_cert else ()
return contextlib.suppress(*suppressed)
@property
def verbose(self) -> bool:
return self._verbose
@verbose.setter
def verbose(self, verbose: bool) -> None:
"""Initialize a logger based on the --verbose option."""
self._verbose = verbose
root_logger = logging.getLogger("twine")
root_logger.addHandler(logging.StreamHandler(sys.stdout))
root_logger.setLevel(logging.INFO if verbose else logging.WARNING)
@staticmethod
def register_argparse_arguments(parser: argparse.ArgumentParser) -> None:
"""Register the arguments for argparse."""
parser.add_argument(
"-r",
"--repository",
action=utils.EnvironmentDefault,
env="TWINE_REPOSITORY",
default="pypi",
help="The repository (package index) to upload the package to. "
"Should be a section in the config file (default: "
"%(default)s). (Can also be set via %(env)s environment "
"variable.)",
)
parser.add_argument(
"--repository-url",
action=utils.EnvironmentDefault,
env="TWINE_REPOSITORY_URL",
default=None,
required=False,
help="The repository (package index) URL to upload the package to."
" This overrides --repository. "
"(Can also be set via %(env)s environment variable.)",
)
parser.add_argument(
"-s",
"--sign",
action="store_true",
default=False,
help="Sign files to upload using GPG.",
)
parser.add_argument(
"--sign-with",
default="gpg",
help="GPG program used to sign uploads (default: %(default)s).",
)
parser.add_argument(
"-i",
"--identity",
help="GPG identity used to sign files.",
)
parser.add_argument(
"-u",
"--username",
action=utils.EnvironmentDefault,
env="TWINE_USERNAME",
required=False,
help="The username to authenticate to the repository "
"(package index) as. (Can also be set via "
"%(env)s environment variable.)",
)
parser.add_argument(
"-p",
"--password",
action=utils.EnvironmentDefault,
env="TWINE_PASSWORD",
required=False,
help="The password to authenticate to the repository "
"(package index) with. (Can also be set via "
"%(env)s environment variable.)",
)
parser.add_argument(
"--non-interactive",
action=utils.EnvironmentFlag,
env="TWINE_NON_INTERACTIVE",
help="Do not interactively prompt for username/password if the "
"required credentials are missing. (Can also be set via "
"%(env)s environment variable.)",
)
parser.add_argument(
"-c",
"--comment",
help="The comment to include with the distribution file.",
)
parser.add_argument(
"--config-file",
default="~/.pypirc",
help="The .pypirc config file to use.",
)
parser.add_argument(
"--skip-existing",
default=False,
action="store_true",
help="Continue uploading files if one already exists. (Only valid "
"when uploading to PyPI. Other implementations may not "
"support this.)",
)
parser.add_argument(
"--cert",
action=utils.EnvironmentDefault,
env="TWINE_CERT",
default=None,
required=False,
metavar="path",
help="Path to alternate CA bundle (can also be set via %(env)s "
"environment variable).",
)
parser.add_argument(
"--client-cert",
metavar="path",
help="Path to SSL client certificate, a single file containing the"
" private key and the certificate in PEM format.",
)
parser.add_argument(
"--verbose",
default=False,
required=False,
action="store_true",
help="Show verbose output.",
)
parser.add_argument(
"--disable-progress-bar",
default=False,
required=False,
action="store_true",
help="Disable the progress bar.",
)
@classmethod
def from_argparse(cls, args: argparse.Namespace) -> "Settings":
"""Generate the Settings from parsed arguments."""
settings = vars(args)
settings["repository_name"] = settings.pop("repository")
settings["cacert"] = settings.pop("cert")
return cls(**settings)
def _handle_package_signing(
self, sign: bool, sign_with: str, identity: Optional[str]
) -> None:
if not sign and identity:
raise exceptions.InvalidSigningConfiguration(
"sign must be given along with identity"
)
self.sign = sign
self.sign_with = sign_with
self.identity = identity
def _handle_repository_options(
self, repository_name: str, repository_url: Optional[str]
) -> None:
self.repository_config = utils.get_repository_from_config(
self.config_file,
repository_name,
repository_url,
)
self.repository_config["repository"] = utils.normalize_repository_url(
cast(str, self.repository_config["repository"]),
)
def _handle_certificates(
self, cacert: Optional[str], client_cert: Optional[str]
) -> None:
self.cacert = utils.get_cacert(cacert, self.repository_config)
self.client_cert = utils.get_clientcert(client_cert, self.repository_config)
def check_repository_url(self) -> None:
"""Verify we are not using legacy PyPI.
:raises:
:class:`~twine.exceptions.UploadToDeprecatedPyPIDetected`
"""
repository_url = cast(str, self.repository_config["repository"])
if repository_url.startswith(
(repository.LEGACY_PYPI, repository.LEGACY_TEST_PYPI)
):
raise exceptions.UploadToDeprecatedPyPIDetected.from_args(
repository_url, utils.DEFAULT_REPOSITORY, utils.TEST_REPOSITORY
)
def create_repository(self) -> repository.Repository:
"""Create a new repository for uploading."""
repo = repository.Repository(
cast(str, self.repository_config["repository"]),
self.username,
self.password,
self.disable_progress_bar,
)
repo.set_certificate_authority(self.cacert)
repo.set_client_certificate(self.client_cert)
return repo

@ -0,0 +1,297 @@
# Copyright 2013 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import collections
import configparser
import functools
import logging
import os
import os.path
from typing import Any, Callable, DefaultDict, Dict, Optional, Sequence, Union
from urllib.parse import urlparse
from urllib.parse import urlunparse
import requests
import rfc3986
from twine import exceptions
# Shim for input to allow testing.
input_func = input
DEFAULT_REPOSITORY = "https://upload.pypi.org/legacy/"
TEST_REPOSITORY = "https://test.pypi.org/legacy/"
# TODO: In general, it seems to be assumed that the values retrieved from
# instances of this type aren't None, except for username and password.
# Type annotations would be cleaner if this were Dict[str, str], but that
# requires reworking the username/password handling, probably starting with
# get_userpass_value.
RepositoryConfig = Dict[str, Optional[str]]
logger = logging.getLogger(__name__)
def get_config(path: str = "~/.pypirc") -> Dict[str, RepositoryConfig]:
# even if the config file does not exist, set up the parser
# variable to reduce the number of if/else statements
parser = configparser.RawConfigParser()
# this list will only be used if index-servers
# is not defined in the config file
index_servers = ["pypi", "testpypi"]
# default configuration for each repository
defaults: RepositoryConfig = {"username": None, "password": None}
# Expand user strings in the path
path = os.path.expanduser(path)
logger.info(f"Using configuration from {path}")
# Parse the rc file
if os.path.isfile(path):
parser.read(path)
# Get a list of index_servers from the config file
# format: https://packaging.python.org/specifications/pypirc/
if parser.has_option("distutils", "index-servers"):
index_servers = parser.get("distutils", "index-servers").split()
for key in ["username", "password"]:
if parser.has_option("server-login", key):
defaults[key] = parser.get("server-login", key)
config: DefaultDict[str, RepositoryConfig] = collections.defaultdict(
lambda: defaults.copy()
)
# don't require users to manually configure URLs for these repositories
config["pypi"]["repository"] = DEFAULT_REPOSITORY
if "testpypi" in index_servers:
config["testpypi"]["repository"] = TEST_REPOSITORY
# optional configuration values for individual repositories
for repository in index_servers:
for key in [
"username",
"repository",
"password",
"ca_cert",
"client_cert",
]:
if parser.has_option(repository, key):
config[repository][key] = parser.get(repository, key)
# convert the defaultdict to a regular dict at this point
# to prevent surprising behavior later on
return dict(config)
def _validate_repository_url(repository_url: str) -> None:
"""Validate the given url for allowed schemes and components."""
# Allowed schemes are http and https, based on whether the repository
# supports TLS or not, and scheme and host must be present in the URL
validator = (
rfc3986.validators.Validator()
.allow_schemes("http", "https")
.require_presence_of("scheme", "host")
)
try:
validator.validate(rfc3986.uri_reference(repository_url))
except rfc3986.exceptions.RFC3986Exception as exc:
raise exceptions.UnreachableRepositoryURLDetected(
f"Invalid repository URL: {exc.args[0]}."
)
def get_repository_from_config(
config_file: str, repository: str, repository_url: Optional[str] = None
) -> RepositoryConfig:
# Get our config from, if provided, command-line values for the
# repository name and URL, or the .pypirc file
if repository_url:
_validate_repository_url(repository_url)
# prefer CLI `repository_url` over `repository` or .pypirc
return {
"repository": repository_url,
"username": None,
"password": None,
}
try:
return get_config(config_file)[repository]
except KeyError:
msg = (
"Missing '{repo}' section from the configuration file\n"
"or not a complete URL in --repository-url.\n"
"Maybe you have an out-dated '{cfg}' format?\n"
"more info: "
"https://packaging.python.org/specifications/pypirc/\n"
).format(repo=repository, cfg=config_file)
raise exceptions.InvalidConfiguration(msg)
_HOSTNAMES = {
"pypi.python.org",
"testpypi.python.org",
"upload.pypi.org",
"test.pypi.org",
}
def normalize_repository_url(url: str) -> str:
parsed = urlparse(url)
if parsed.netloc in _HOSTNAMES:
return urlunparse(("https",) + parsed[1:])
return urlunparse(parsed)
def get_file_size(filename: str) -> str:
"""Return the size of a file in KB, or MB if >= 1024 KB."""
file_size = os.path.getsize(filename) / 1024
size_unit = "KB"
if file_size > 1024:
file_size = file_size / 1024
size_unit = "MB"
return f"{file_size:.1f} {size_unit}"
def check_status_code(response: requests.Response, verbose: bool) -> None:
"""Generate a helpful message based on the response from the repository.
Raise a custom exception for recognized errors. Otherwise, print the
response content (based on the verbose option) before re-raising the
HTTPError.
"""
if response.status_code == 410 and "pypi.python.org" in response.url:
raise exceptions.UploadToDeprecatedPyPIDetected(
f"It appears you're uploading to pypi.python.org (or "
f"testpypi.python.org). You've received a 410 error response. "
f"Uploading to those sites is deprecated. The new sites are "
f"pypi.org and test.pypi.org. Try using {DEFAULT_REPOSITORY} (or "
f"{TEST_REPOSITORY}) to upload your packages instead. These are "
f"the default URLs for Twine now. More at "
f"https://packaging.python.org/guides/migrating-to-pypi-org/."
)
elif response.status_code == 405 and "pypi.org" in response.url:
raise exceptions.InvalidPyPIUploadURL(
f"It appears you're trying to upload to pypi.org but have an "
f"invalid URL. You probably want one of these two URLs: "
f"{DEFAULT_REPOSITORY} or {TEST_REPOSITORY}. Check your "
f"--repository-url value."
)
try:
response.raise_for_status()
except requests.HTTPError as err:
if response.text:
logger.info("Content received from server:\n{}".format(response.text))
if not verbose:
logger.warning("NOTE: Try --verbose to see response content.")
raise err
def get_userpass_value(
cli_value: Optional[str],
config: RepositoryConfig,
key: str,
prompt_strategy: Optional[Callable[[], str]] = None,
) -> Optional[str]:
"""Get the username / password from config.
Uses the following rules:
1. If it is specified on the cli (`cli_value`), use that.
2. If `config[key]` is specified, use that.
3. If `prompt_strategy`, prompt using `prompt_strategy`.
4. Otherwise return None
:param cli_value: The value supplied from the command line or `None`.
:type cli_value: unicode or `None`
:param config: Config dictionary
:type config: dict
:param key: Key to find the config value.
:type key: unicode
:prompt_strategy: Argumentless function to return fallback value.
:type prompt_strategy: function
:returns: The value for the username / password
:rtype: unicode
"""
if cli_value is not None:
logger.info(f"{key} set by command options")
return cli_value
elif config.get(key) is not None:
logger.info(f"{key} set from config file")
return config[key]
elif prompt_strategy:
return prompt_strategy()
else:
return None
get_cacert = functools.partial(get_userpass_value, key="ca_cert")
get_clientcert = functools.partial(get_userpass_value, key="client_cert")
class EnvironmentDefault(argparse.Action):
"""Get values from environment variable."""
def __init__(
self,
env: str,
required: bool = True,
default: Optional[str] = None,
**kwargs: Any,
) -> None:
default = os.environ.get(env, default)
self.env = env
if default:
required = False
super().__init__(default=default, required=required, **kwargs)
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Union[str, Sequence[Any], None],
option_string: Optional[str] = None,
) -> None:
setattr(namespace, self.dest, values)
class EnvironmentFlag(argparse.Action):
"""Set boolean flag from environment variable."""
def __init__(self, env: str, **kwargs: Any) -> None:
default = self.bool_from_env(os.environ.get(env))
self.env = env
super().__init__(default=default, nargs=0, **kwargs)
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Union[str, Sequence[Any], None],
option_string: Optional[str] = None,
) -> None:
setattr(namespace, self.dest, True)
@staticmethod
def bool_from_env(val: Optional[str]) -> bool:
"""Allow '0' and 'false' and 'no' to be False."""
falsey = {"0", "false", "no"}
return bool(val and val.lower() not in falsey)

@ -0,0 +1,91 @@
# Copyright 2013 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import re
import zipfile
from typing import List, Optional
from pkginfo import distribution
from twine import exceptions
# Monkeypatch Metadata 2.0 support
distribution.HEADER_ATTRS_2_0 = distribution.HEADER_ATTRS_1_2
distribution.HEADER_ATTRS.update({"2.0": distribution.HEADER_ATTRS_2_0})
wheel_file_re = re.compile(
r"""^(?P<namever>(?P<name>.+?)(-(?P<ver>\d.+?))?)
((-(?P<build>\d.*?))?-(?P<pyver>.+?)-(?P<abi>.+?)-(?P<plat>.+?)
\.whl|\.dist-info)$""",
re.VERBOSE,
)
class Wheel(distribution.Distribution):
def __init__(self, filename: str, metadata_version: Optional[str] = None) -> None:
self.filename = filename
self.basefilename = os.path.basename(self.filename)
self.metadata_version = metadata_version
self.extractMetadata()
@property
def py_version(self) -> str:
wheel_info = wheel_file_re.match(self.basefilename)
if wheel_info is None:
return "any"
else:
return wheel_info.group("pyver")
@staticmethod
def find_candidate_metadata_files(names: List[str]) -> List[List[str]]:
"""Filter files that may be METADATA files."""
tuples = [x.split("/") for x in names if "METADATA" in x]
return [x[1] for x in sorted([(len(x), x) for x in tuples])]
def read(self) -> bytes:
fqn = os.path.abspath(os.path.normpath(self.filename))
if not os.path.exists(fqn):
raise exceptions.InvalidDistribution("No such file: %s" % fqn)
if fqn.endswith(".whl"):
archive = zipfile.ZipFile(fqn)
names = archive.namelist()
def read_file(name: str) -> bytes:
return archive.read(name)
else:
raise exceptions.InvalidDistribution(
"Not a known archive format for file: %s" % fqn
)
try:
for path in self.find_candidate_metadata_files(names):
candidate = "/".join(path)
data = read_file(candidate)
if b"Metadata-Version" in data:
return data
finally:
archive.close()
raise exceptions.InvalidDistribution("No METADATA in archive: %s" % fqn)
def parse(self, data: bytes) -> None:
super().parse(data)
fp = io.StringIO(distribution.must_decode(data))
msg = distribution.parse(fp)
self.description = msg.get_payload()

@ -0,0 +1,61 @@
import os
import re
import zipfile
from typing import Optional
from pkginfo import distribution
from twine import exceptions
wininst_file_re = re.compile(r".*py(?P<pyver>\d+\.\d+)\.exe$")
class WinInst(distribution.Distribution):
def __init__(self, filename: str, metadata_version: Optional[str] = None) -> None:
self.filename = filename
self.metadata_version = metadata_version
self.extractMetadata()
@property
def py_version(self) -> str:
m = wininst_file_re.match(self.filename)
if m is None:
return "any"
else:
return m.group("pyver")
def read(self) -> bytes:
fqn = os.path.abspath(os.path.normpath(self.filename))
if not os.path.exists(fqn):
raise exceptions.InvalidDistribution("No such file: %s" % fqn)
if fqn.endswith(".exe"):
archive = zipfile.ZipFile(fqn)
names = archive.namelist()
def read_file(name: str) -> bytes:
return archive.read(name)
else:
raise exceptions.InvalidDistribution(
"Not a known archive format for file: %s" % fqn
)
try:
tuples = [
x.split("/")
for x in names
if x.endswith(".egg-info") or x.endswith("PKG-INFO")
]
schwarz = sorted([(len(x), x) for x in tuples])
for path in [x[1] for x in schwarz]:
candidate = "/".join(path)
data = read_file(candidate)
if b"Metadata-Version" in data:
return data
finally:
archive.close()
raise exceptions.InvalidDistribution(
"No PKG-INFO/.egg-info in archive: %s" % fqn
)

@ -13,11 +13,13 @@ enzyme=0.4.1
ffsubsync=0.4.11
Flask=1.1.1
flask-socketio=5.0.2dev
gevent-websocker=0.10.1
gitpython=2.1.9
guessit=3.3.1
guess_language-spirit=0.5.3
Js2Py=0.63 <-- modified: manually merged from upstream: https://github.com/PiotrDabkowski/Js2Py/pull/192/files
knowit=0.3.0-dev
msgpack=1.0.2
py-pretty=1
pycountry=18.2.23
pyga=2.6.1
@ -27,13 +29,18 @@ rarfile=3.0
rebulk=3.0.1
requests=2.18.4
semver=2.13.0
signalr-client=0.0.7 <-- Modified to work with Sonarr
signalrcore=0.9.2
SimpleConfigParser=0.1.0 <-- modified version: do not update!!!
six=1.11.0
socketio=5.1.0
sseclient=0.0.27 <-- Modified to work with Sonarr
stevedore=1.28.0
subliminal=2.1.0dev
tzlocal=2.1b1
twine=3.4.1
urllib3=1.23
websocket-client=0.54.0
## indirect dependencies
auditok=0.1.5 # Required-by: ffsubsync

@ -26,4 +26,4 @@ from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "0.44.0"
__version__ = "0.54.0"

@ -30,12 +30,20 @@ from ._utils import validate_utf8
from threading import Lock
try:
# If wsaccel is available we use compiled routines to mask data.
from wsaccel.xormask import XorMaskerSimple
if six.PY3:
import numpy
else:
numpy = None
except ImportError:
numpy = None
def _mask(_m, _d):
return XorMaskerSimple(_m).process(_d)
try:
# If wsaccel is available we use compiled routines to mask data.
if not numpy:
from wsaccel.xormask import XorMaskerSimple
def _mask(_m, _d):
return XorMaskerSimple(_m).process(_d)
except ImportError:
# wsaccel is not available, we rely on python implementations.
def _mask(_m, _d):
@ -47,6 +55,7 @@ except ImportError:
else:
return _d.tostring()
__all__ = [
'ABNF', 'continuous_frame', 'frame_buffer',
'STATUS_NORMAL',
@ -258,9 +267,21 @@ class ABNF(object):
if isinstance(data, six.text_type):
data = six.b(data)
_m = array.array("B", mask_key)
_d = array.array("B", data)
return _mask(_m, _d)
if numpy:
origlen = len(data)
_mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]
# We need data to be a multiple of four...
data += bytes(" " * (4 - (len(data) % 4)), "us-ascii")
a = numpy.frombuffer(data, dtype="uint32")
masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32")
if len(data) > origlen:
return masked.tobytes()[:origlen]
return masked.tobytes()
else:
_m = array.array("B", mask_key)
_d = array.array("B", data)
return _mask(_m, _d)
class frame_buffer(object):

@ -23,6 +23,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
"""
WebSocketApp provides higher level APIs.
"""
import inspect
import select
import sys
import threading
@ -39,6 +40,40 @@ from . import _logging
__all__ = ["WebSocketApp"]
class Dispatcher:
def __init__(self, app, ping_timeout):
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, read_callback, check_callback):
while self.app.sock.connected:
r, w, e = select.select(
(self.app.sock.sock, ), (), (), self.ping_timeout)
if r:
if not read_callback():
break
check_callback()
class SSLDispacther:
def __init__(self, app, ping_timeout):
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, read_callback, check_callback):
while self.app.sock.connected:
r = self.select()
if r:
if not read_callback():
break
check_callback()
def select(self):
sock = self.app.sock.sock
if sock.pending():
return [sock,]
r, w, e = select.select((sock, ), (), (), self.ping_timeout)
return r
class WebSocketApp(object):
"""
@ -83,8 +118,7 @@ class WebSocketApp(object):
The 2nd argument is utf-8 string which we get from the server.
The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came.
The 4th argument is continue flag. if 0, the data continue
keep_running: a boolean flag indicating whether the app's main loop
should keep running, defaults to True
keep_running: this parameter is obsolete and ignored.
get_mask_key: a callable to produce new mask keys,
see the WebSocket.set_mask_key's docstring for more information
subprotocols: array of available sub protocols. default is None.
@ -92,6 +126,7 @@ class WebSocketApp(object):
self.url = url
self.header = header if header is not None else []
self.cookie = cookie
self.on_open = on_open
self.on_message = on_message
self.on_data = on_data
@ -100,7 +135,7 @@ class WebSocketApp(object):
self.on_ping = on_ping
self.on_pong = on_pong
self.on_cont_message = on_cont_message
self.keep_running = keep_running
self.keep_running = False
self.get_mask_key = get_mask_key
self.sock = None
self.last_ping_tm = 0
@ -126,6 +161,7 @@ class WebSocketApp(object):
self.keep_running = False
if self.sock:
self.sock.close(**kwargs)
self.sock = None
def _send_ping(self, interval, event):
while not event.wait(interval):
@ -142,7 +178,8 @@ class WebSocketApp(object):
http_proxy_host=None, http_proxy_port=None,
http_no_proxy=None, http_proxy_auth=None,
skip_utf8_validation=False,
host=None, origin=None):
host=None, origin=None, dispatcher=None,
suppress_origin = False, proxy_type=None):
"""
run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available.
@ -160,33 +197,64 @@ class WebSocketApp(object):
skip_utf8_validation: skip utf8 validation.
host: update host header.
origin: update origin header.
dispatcher: customize reading data from socket.
suppress_origin: suppress outputting origin header.
Returns
-------
False if caught KeyboardInterrupt
True if other exception was raised during a loop
"""
if not ping_timeout or ping_timeout <= 0:
if ping_timeout is not None and ping_timeout <= 0:
ping_timeout = None
if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout")
if sockopt is None:
if not sockopt:
sockopt = []
if sslopt is None:
if not sslopt:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
thread = None
close_frame = None
self.keep_running = True
self.last_ping_tm = 0
self.last_pong_tm = 0
def teardown(close_frame=None):
"""
Tears down the connection.
If close_frame is set, we will invoke the on_close handler with the
statusCode and reason from there.
"""
if thread and thread.isAlive():
event.set()
thread.join()
self.keep_running = False
if self.sock:
self.sock.close()
close_args = self._get_close_args(
close_frame.data if close_frame else None)
self._callback(self.on_close, *close_args)
self.sock = None
try:
self.sock = WebSocket(
self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message and True or False,
skip_utf8_validation=skip_utf8_validation)
fire_cont_frame=self.on_cont_message is not None,
skip_utf8_validation=skip_utf8_validation,
enable_multithread=True if ping_interval else False)
self.sock.settimeout(getdefaulttimeout())
self.sock.connect(
self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy,
http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
host=host, origin=origin)
host=host, origin=origin, suppress_origin=suppress_origin,
proxy_type=proxy_type)
if not dispatcher:
dispatcher = self.create_dispatcher(ping_timeout)
self._callback(self.on_open)
if ping_interval:
@ -196,58 +264,63 @@ class WebSocketApp(object):
thread.setDaemon(True)
thread.start()
while self.sock.connected:
r, w, e = select.select(
(self.sock.sock, ), (), (), ping_timeout or 10) # Use a 10 second timeout to avoid to wait forever on close
def read():
if not self.keep_running:
break
return teardown()
if r:
op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE:
close_frame = frame
break
elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data)
elif op_code == ABNF.OPCODE_PONG:
self.last_pong_tm = time.time()
self._callback(self.on_pong, frame.data)
elif op_code == ABNF.OPCODE_CONT and self.on_cont_message:
self._callback(self.on_data, data,
frame.opcode, frame.fin)
self._callback(self.on_cont_message,
frame.data, frame.fin)
else:
data = frame.data
if six.PY3 and op_code == ABNF.OPCODE_TEXT:
data = data.decode("utf-8")
self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data)
if ping_timeout and self.last_ping_tm \
and time.time() - self.last_ping_tm > ping_timeout \
and self.last_ping_tm - self.last_pong_tm > ping_timeout:
raise WebSocketTimeoutException("ping/pong timed out")
op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE:
return teardown(frame)
elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data)
elif op_code == ABNF.OPCODE_PONG:
self.last_pong_tm = time.time()
self._callback(self.on_pong, frame.data)
elif op_code == ABNF.OPCODE_CONT and self.on_cont_message:
self._callback(self.on_data, frame.data,
frame.opcode, frame.fin)
self._callback(self.on_cont_message,
frame.data, frame.fin)
else:
data = frame.data
if six.PY3 and op_code == ABNF.OPCODE_TEXT:
data = data.decode("utf-8")
self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data)
return True
def check():
if (ping_timeout):
has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout
has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0
has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout
if (self.last_ping_tm
and has_timeout_expired
and (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)):
raise WebSocketTimeoutException("ping/pong timed out")
return True
dispatcher.read(self.sock.sock, read, check)
except (Exception, KeyboardInterrupt, SystemExit) as e:
self._callback(self.on_error, e)
if isinstance(e, SystemExit):
# propagate SystemExit further
raise
finally:
if thread and thread.isAlive():
event.set()
thread.join()
self.keep_running = False
self.sock.close()
close_args = self._get_close_args(
close_frame.data if close_frame else None)
self._callback(self.on_close, *close_args)
self.sock = None
teardown()
return not isinstance(e, KeyboardInterrupt)
def create_dispatcher(self, ping_timeout):
timeout = ping_timeout or 10
if self.sock.is_ssl():
return SSLDispacther(self, timeout)
return Dispatcher(self, timeout)
def _get_close_args(self, data):
""" this functions extracts the code, reason from the close body
if they exists, and if the self.on_close except three arguments """
import inspect
# if the on_close callback is "old", just return empty list
if sys.version_info < (3, 0):
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3:
@ -266,7 +339,11 @@ class WebSocketApp(object):
def _callback(self, callback, *args):
if callback:
try:
callback(self, *args)
if inspect.ismethod(callback):
callback(*args)
else:
callback(self, *args)
except Exception as e:
_logging.error("error from callback {}: {}".format(callback, e))
if _logging.isEnabledForDebug():

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save