diff --git a/apiclient/__init__.py b/apiclient/__init__.py deleted file mode 100644 index f901408..0000000 --- a/apiclient/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "1.1" diff --git a/apiclient/ext/__init__.py b/apiclient/ext/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/apiclient/push.py b/apiclient/push.py deleted file mode 100644 index c520faf..0000000 --- a/apiclient/push.py +++ /dev/null @@ -1,274 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Push notifications support. - -This code is based on experimental APIs and is subject to change. -""" - -__author__ = 'afshar@google.com (Ali Afshar)' - -import binascii -import collections -import os -import urllib - -SUBSCRIBE = 'X-GOOG-SUBSCRIBE' -SUBSCRIPTION_ID = 'X-GOOG-SUBSCRIPTION-ID' -TOPIC_ID = 'X-GOOG-TOPIC-ID' -TOPIC_URI = 'X-GOOG-TOPIC-URI' -CLIENT_TOKEN = 'X-GOOG-CLIENT-TOKEN' -EVENT_TYPE = 'X-GOOG-EVENT-TYPE' -UNSUBSCRIBE = 'X-GOOG-UNSUBSCRIBE' - - -class InvalidSubscriptionRequestError(ValueError): - """The request cannot be subscribed.""" - - -def new_token(): - """Gets a random token for use as a client_token in push notifications. - - Returns: - str, a new random token. - """ - return binascii.hexlify(os.urandom(32)) - - -class Channel(object): - """Base class for channel types.""" - - def __init__(self, channel_type, channel_args): - """Create a new Channel. 
- - You probably won't need to create this channel manually, since there are - subclassed Channel for each specific type with a more customized set of - arguments to pass. However, you may wish to just create it manually here. - - Args: - channel_type: str, the type of channel. - channel_args: dict, arguments to pass to the channel. - """ - self.channel_type = channel_type - self.channel_args = channel_args - - def as_header_value(self): - """Create the appropriate header for this channel. - - Returns: - str encoded channel description suitable for use as a header. - """ - return '%s?%s' % (self.channel_type, urllib.urlencode(self.channel_args)) - - def write_header(self, headers): - """Write the appropriate subscribe header to a headers dict. - - Args: - headers: dict, headers to add subscribe header to. - """ - headers[SUBSCRIBE] = self.as_header_value() - - -class WebhookChannel(Channel): - """Channel for registering web hook notifications.""" - - def __init__(self, url, app_engine=False): - """Create a new WebhookChannel - - Args: - url: str, URL to post notifications to. - app_engine: bool, default=False, whether the destination for the - notifications is an App Engine application. - """ - super(WebhookChannel, self).__init__( - channel_type='web_hook', - channel_args={ - 'url': url, - 'app_engine': app_engine and 'true' or 'false', - } - ) - - -class Headers(collections.defaultdict): - """Headers for managing subscriptions.""" - - - ALL_HEADERS = set([SUBSCRIBE, SUBSCRIPTION_ID, TOPIC_ID, TOPIC_URI, - CLIENT_TOKEN, EVENT_TYPE, UNSUBSCRIBE]) - - def __init__(self): - """Create a new subscription configuration instance.""" - collections.defaultdict.__init__(self, str) - - def __setitem__(self, key, value): - """Set a header value, ensuring the key is an allowed value. - - Args: - key: str, the header key. - value: str, the header value. - Raises: - ValueError if key is not one of the accepted headers. 
- """ - normal_key = self._normalize_key(key) - if normal_key not in self.ALL_HEADERS: - raise ValueError('Header name must be one of %s.' % self.ALL_HEADERS) - else: - return collections.defaultdict.__setitem__(self, normal_key, value) - - def __getitem__(self, key): - """Get a header value, normalizing the key case. - - Args: - key: str, the header key. - Returns: - String header value. - Raises: - KeyError if the key is not one of the accepted headers. - """ - normal_key = self._normalize_key(key) - if normal_key not in self.ALL_HEADERS: - raise ValueError('Header name must be one of %s.' % self.ALL_HEADERS) - else: - return collections.defaultdict.__getitem__(self, normal_key) - - def _normalize_key(self, key): - """Normalize a header name for use as a key.""" - return key.upper() - - def items(self): - """Generator for each header.""" - for header in self.ALL_HEADERS: - value = self[header] - if value: - yield header, value - - def write(self, headers): - """Applies the subscription headers. - - Args: - headers: dict of headers to insert values into. - """ - for header, value in self.items(): - headers[header.lower()] = value - - def read(self, headers): - """Read from headers. - - Args: - headers: dict of headers to read from. - """ - for header in self.ALL_HEADERS: - if header.lower() in headers: - self[header] = headers[header.lower()] - - -class Subscription(object): - """Information about a subscription.""" - - def __init__(self): - """Create a new Subscription.""" - self.headers = Headers() - - @classmethod - def for_request(cls, request, channel, client_token=None): - """Creates a subscription and attaches it to a request. - - Args: - request: An http.HttpRequest to modify for making a subscription. - channel: A apiclient.push.Channel describing the subscription to - create. - client_token: (optional) client token to verify the notification. - - Returns: - New subscription object. 
- """ - subscription = cls.for_channel(channel=channel, client_token=client_token) - subscription.headers.write(request.headers) - if request.method != 'GET': - raise InvalidSubscriptionRequestError( - 'Can only subscribe to requests which are GET.') - request.method = 'POST' - - def _on_response(response, subscription=subscription): - """Called with the response headers. Reads the subscription headers.""" - subscription.headers.read(response) - - request.add_response_callback(_on_response) - return subscription - - @classmethod - def for_channel(cls, channel, client_token=None): - """Alternate constructor to create a subscription from a channel. - - Args: - channel: A apiclient.push.Channel describing the subscription to - create. - client_token: (optional) client token to verify the notification. - - Returns: - New subscription object. - """ - subscription = cls() - channel.write_header(subscription.headers) - if client_token is None: - client_token = new_token() - subscription.headers[SUBSCRIPTION_ID] = new_token() - subscription.headers[CLIENT_TOKEN] = client_token - return subscription - - def verify(self, headers): - """Verifies that a webhook notification has the correct client_token. - - Args: - headers: dict of request headers for a push notification. - - Returns: - Boolean value indicating whether the notification is verified. 
- """ - new_subscription = Subscription() - new_subscription.headers.read(headers) - return new_subscription.client_token == self.client_token - - @property - def subscribe(self): - """Subscribe header value.""" - return self.headers[SUBSCRIBE] - - @property - def subscription_id(self): - """Subscription ID header value.""" - return self.headers[SUBSCRIPTION_ID] - - @property - def topic_id(self): - """Topic ID header value.""" - return self.headers[TOPIC_ID] - - @property - def topic_uri(self): - """Topic URI header value.""" - return self.headers[TOPIC_URI] - - @property - def client_token(self): - """Client Token header value.""" - return self.headers[CLIENT_TOKEN] - - @property - def event_type(self): - """Event Type header value.""" - return self.headers[EVENT_TYPE] - - @property - def unsubscribe(self): - """Unsuscribe header value.""" - return self.headers[UNSUBSCRIBE] diff --git a/build.bat b/build.bat index eb26193..0b121a4 100644 --- a/build.bat +++ b/build.bat @@ -1,19 +1,10 @@ rmdir /q /s gyb -rmdir /q /s gyb-64 rmdir /q /s build rmdir /q /s dist del /q /f gyb-%1-windows.zip -del /q /f gyb-%1-windows-x64.zip -c:\python27-32\scripts\pyinstaller --distpath=gyb gyb.spec +c:\python3\scripts\pyinstaller --distpath=gyb gyb.spec xcopy LICENSE gyb\ xcopy cacert.pem gyb\ xcopy client_secrets.json gyb\ -del gyb\w9xpopen.exe -"%ProgramFiles(x86)%\7-Zip\7z.exe" a -tzip gyb-%1-windows.zip gyb\ -xr!.svn - -c:\python27\scripts\pyinstaller --distpath=gyb-64 gyb.spec -xcopy LICENSE gyb-64\ -xcopy cacert.pem gyb-64\ -xcopy client_secrets.json gyb-64\ -"%ProgramFiles(x86)%\7-Zip\7z.exe" a -tzip gyb-%1-windows-x64.zip gyb-64\ -xr!.svn \ No newline at end of file +"%ProgramFiles(x86)%\7-Zip\7z.exe" a -tzip gyb-%1-windows.zip gyb\ -xr!.svn \ No newline at end of file diff --git a/gflags.py b/gflags.py deleted file mode 100644 index 822256a..0000000 --- a/gflags.py +++ /dev/null @@ -1,2862 +0,0 @@ -#!/usr/bin/env python -# -# Copyright (c) 2002, Google Inc. 
-# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# --- -# Author: Chad Lester -# Design and style contributions by: -# Amit Patel, Bogdan Cocosel, Daniel Dulitz, Eric Tiedemann, -# Eric Veach, Laurence Gonsalves, Matthew Springer -# Code reorganized a bit by Craig Silverstein - -"""This module is used to define and parse command line flags. 
- -This module defines a *distributed* flag-definition policy: rather than -an application having to define all flags in or near main(), each python -module defines flags that are useful to it. When one python module -imports another, it gains access to the other's flags. (This is -implemented by having all modules share a common, global registry object -containing all the flag information.) - -Flags are defined through the use of one of the DEFINE_xxx functions. -The specific function used determines how the flag is parsed, checked, -and optionally type-converted, when it's seen on the command line. - - -IMPLEMENTATION: DEFINE_* creates a 'Flag' object and registers it with a -'FlagValues' object (typically the global FlagValues FLAGS, defined -here). The 'FlagValues' object can scan the command line arguments and -pass flag arguments to the corresponding 'Flag' objects for -value-checking and type conversion. The converted flag values are -available as attributes of the 'FlagValues' object. - -Code can access the flag through a FlagValues object, for instance -gflags.FLAGS.myflag. Typically, the __main__ module passes the command -line arguments to gflags.FLAGS for parsing. - -At bottom, this module calls getopt(), so getopt functionality is -supported, including short- and long-style flags, and the use of -- to -terminate flags. - -Methods defined by the flag module will throw 'FlagsError' exceptions. -The exception argument will be a human-readable string. - - -FLAG TYPES: This is a list of the DEFINE_*'s that you can do. All flags -take a name, default value, help-string, and optional 'short' name -(one-letter name). Some flags have other arguments, which are described -with the flag. - -DEFINE_string: takes any input, and interprets it as a string. - -DEFINE_bool or -DEFINE_boolean: typically does not take an argument: say --myflag to - set FLAGS.myflag to true, or --nomyflag to set - FLAGS.myflag to false. 
Alternately, you can say - --myflag=true or --myflag=t or --myflag=1 or - --myflag=false or --myflag=f or --myflag=0 - -DEFINE_float: takes an input and interprets it as a floating point - number. Takes optional args lower_bound and upper_bound; - if the number specified on the command line is out of - range, it will raise a FlagError. - -DEFINE_integer: takes an input and interprets it as an integer. Takes - optional args lower_bound and upper_bound as for floats. - -DEFINE_enum: takes a list of strings which represents legal values. If - the command-line value is not in this list, raise a flag - error. Otherwise, assign to FLAGS.flag as a string. - -DEFINE_list: Takes a comma-separated list of strings on the commandline. - Stores them in a python list object. - -DEFINE_spaceseplist: Takes a space-separated list of strings on the - commandline. Stores them in a python list object. - Example: --myspacesepflag "foo bar baz" - -DEFINE_multistring: The same as DEFINE_string, except the flag can be - specified more than once on the commandline. The - result is a python list object (list of strings), - even if the flag is only on the command line once. - -DEFINE_multi_int: The same as DEFINE_integer, except the flag can be - specified more than once on the commandline. The - result is a python list object (list of ints), even if - the flag is only on the command line once. - - -SPECIAL FLAGS: There are a few flags that have special meaning: - --help prints a list of all the flags in a human-readable fashion - --helpshort prints a list of all key flags (see below). - --helpxml prints a list of all flags, in XML format. DO NOT parse - the output of --help and --helpshort. Instead, parse - the output of --helpxml. For more info, see - "OUTPUT FOR --helpxml" below. - --flagfile=foo read flags from file foo. - --undefok=f1,f2 ignore unrecognized option errors for f1,f2. 
- For boolean flags, you should use --undefok=boolflag, and - --boolflag and --noboolflag will be accepted. Do not use - --undefok=noboolflag. - -- as in getopt(), terminates flag-processing - - -FLAGS VALIDATORS: If your program: - - requires flag X to be specified - - needs flag Y to match a regular expression - - or requires any more general constraint to be satisfied -then validators are for you! - -Each validator represents a constraint over one flag, which is enforced -starting from the initial parsing of the flags and until the program -terminates. - -Also, lower_bound and upper_bound for numerical flags are enforced using flag -validators. - -Howto: -If you want to enforce a constraint over one flag, use - -gflags.RegisterValidator(flag_name, - checker, - message='Flag validation failed', - flag_values=FLAGS) - -After flag values are initially parsed, and after any change to the specified -flag, method checker(flag_value) will be executed. If constraint is not -satisfied, an IllegalFlagValue exception will be raised. See -RegisterValidator's docstring for a detailed explanation on how to construct -your own checker. - - -EXAMPLE USAGE: - -FLAGS = gflags.FLAGS - -gflags.DEFINE_integer('my_version', 0, 'Version number.') -gflags.DEFINE_string('filename', None, 'Input file name', short_name='f') - -gflags.RegisterValidator('my_version', - lambda value: value % 2 == 0, - message='--my_version must be divisible by 2') -gflags.MarkFlagAsRequired('filename') - - -NOTE ON --flagfile: - -Flags may be loaded from text files in addition to being specified on -the commandline. - -Any flags you don't feel like typing, throw them in a file, one flag per -line, for instance: - --myflag=myvalue - --nomyboolean_flag -You then specify your file with the special flag '--flagfile=somefile'. -You CAN recursively nest flagfile= tokens OR use multiple files on the -command line. Lines beginning with a single hash '#' or a double slash -'//' are comments in your flagfile. 
- -Any flagfile= will be interpreted as having a relative path from -the current working directory rather than from the place the file was -included from: - myPythonScript.py --flagfile=config/somefile.cfg - -If somefile.cfg includes further --flagfile= directives, these will be -referenced relative to the original CWD, not from the directory the -including flagfile was found in! - -The caveat applies to people who are including a series of nested files -in a different dir than they are executing out of. Relative path names -are always from CWD, not from the directory of the parent include -flagfile. We do now support '~' expanded directory names. - -Absolute path names ALWAYS work! - - -EXAMPLE USAGE: - - - FLAGS = gflags.FLAGS - - # Flag names are globally defined! So in general, we need to be - # careful to pick names that are unlikely to be used by other libraries. - # If there is a conflict, we'll get an error at import time. - gflags.DEFINE_string('name', 'Mr. President', 'your name') - gflags.DEFINE_integer('age', None, 'your age in years', lower_bound=0) - gflags.DEFINE_boolean('debug', False, 'produces debugging output') - gflags.DEFINE_enum('gender', 'male', ['male', 'female'], 'your gender') - - def main(argv): - try: - argv = FLAGS(argv) # parse flags - except gflags.FlagsError, e: - print '%s\\nUsage: %s ARGS\\n%s' % (e, sys.argv[0], FLAGS) - sys.exit(1) - if FLAGS.debug: print 'non-flag arguments:', argv - print 'Happy Birthday', FLAGS.name - if FLAGS.age is not None: - print 'You are a %d year old %s' % (FLAGS.age, FLAGS.gender) - - if __name__ == '__main__': - main(sys.argv) - - -KEY FLAGS: - -As we already explained, each module gains access to all flags defined -by all the other modules it transitively imports. In the case of -non-trivial scripts, this means a lot of flags ... For documentation -purposes, it is good to identify the flags that are key (i.e., really -important) to a module. Clearly, the concept of "key flag" is a -subjective one. 
When trying to determine whether a flag is key to a -module or not, assume that you are trying to explain your module to a -potential user: which flags would you really like to mention first? - -We'll describe shortly how to declare which flags are key to a module. -For the moment, assume we know the set of key flags for each module. -Then, if you use the app.py module, you can use the --helpshort flag to -print only the help for the flags that are key to the main module, in a -human-readable format. - -NOTE: If you need to parse the flag help, do NOT use the output of ---help / --helpshort. That output is meant for human consumption, and -may be changed in the future. Instead, use --helpxml; flags that are -key for the main module are marked there with a yes element. - -The set of key flags for a module M is composed of: - -1. Flags defined by module M by calling a DEFINE_* function. - -2. Flags that module M explictly declares as key by using the function - - DECLARE_key_flag() - -3. Key flags of other modules that M specifies by using the function - - ADOPT_module_key_flags() - - This is a "bulk" declaration of key flags: each flag that is key for - becomes key for the current module too. - -Notice that if you do not use the functions described at points 2 and 3 -above, then --helpshort prints information only about the flags defined -by the main module of our script. In many cases, this behavior is good -enough. But if you move part of the main module code (together with the -related flags) into a different module, then it is nice to use -DECLARE_key_flag / ADOPT_module_key_flags and make sure --helpshort -lists all relevant flags (otherwise, your code refactoring may confuse -your users). - -Note: each of DECLARE_key_flag / ADOPT_module_key_flags has its own -pluses and minuses: DECLARE_key_flag is more targeted and may lead a -more focused --helpshort documentation. 
ADOPT_module_key_flags is good -for cases when an entire module is considered key to the current script. -Also, it does not require updates to client scripts when a new flag is -added to the module. - - -EXAMPLE USAGE 2 (WITH KEY FLAGS): - -Consider an application that contains the following three files (two -auxiliary modules and a main module) - -File libfoo.py: - - import gflags - - gflags.DEFINE_integer('num_replicas', 3, 'Number of replicas to start') - gflags.DEFINE_boolean('rpc2', True, 'Turn on the usage of RPC2.') - - ... some code ... - -File libbar.py: - - import gflags - - gflags.DEFINE_string('bar_gfs_path', '/gfs/path', - 'Path to the GFS files for libbar.') - gflags.DEFINE_string('email_for_bar_errors', 'bar-team@google.com', - 'Email address for bug reports about module libbar.') - gflags.DEFINE_boolean('bar_risky_hack', False, - 'Turn on an experimental and buggy optimization.') - - ... some code ... - -File myscript.py: - - import gflags - import libfoo - import libbar - - gflags.DEFINE_integer('num_iterations', 0, 'Number of iterations.') - - # Declare that all flags that are key for libfoo are - # key for this module too. - gflags.ADOPT_module_key_flags(libfoo) - - # Declare that the flag --bar_gfs_path (defined in libbar) is key - # for this module. - gflags.DECLARE_key_flag('bar_gfs_path') - - ... some code ... - -When myscript is invoked with the flag --helpshort, the resulted help -message lists information about all the key flags for myscript: ---num_iterations, --num_replicas, --rpc2, and --bar_gfs_path. - -Of course, myscript uses all the flags declared by it (in this case, -just --num_replicas) or by any of the modules it transitively imports -(e.g., the modules libfoo, libbar). E.g., it can access the value of -FLAGS.bar_risky_hack, even if --bar_risky_hack is not declared as a key -flag for myscript. 
- - -OUTPUT FOR --helpxml: - -The --helpxml flag generates output with the following structure: - - - - PROGRAM_BASENAME - MAIN_MODULE_DOCSTRING - ( - [yes] - DECLARING_MODULE - FLAG_NAME - FLAG_HELP_MESSAGE - DEFAULT_FLAG_VALUE - CURRENT_FLAG_VALUE - FLAG_TYPE - [OPTIONAL_ELEMENTS] - )* - - -Notes: - -1. The output is intentionally similar to the output generated by the -C++ command-line flag library. The few differences are due to the -Python flags that do not have a C++ equivalent (at least not yet), -e.g., DEFINE_list. - -2. New XML elements may be added in the future. - -3. DEFAULT_FLAG_VALUE is in serialized form, i.e., the string you can -pass for this flag on the command-line. E.g., for a flag defined -using DEFINE_list, this field may be foo,bar, not ['foo', 'bar']. - -4. CURRENT_FLAG_VALUE is produced using str(). This means that the -string 'false' will be represented in the same way as the boolean -False. Using repr() would have removed this ambiguity and simplified -parsing, but would have broken the compatibility with the C++ -command-line flags. - -5. OPTIONAL_ELEMENTS describe elements relevant for certain kinds of -flags: lower_bound, upper_bound (for flags that specify bounds), -enum_value (for enum flags), list_separator (for flags that consist of -a list of values, separated by a special token). - -6. We do not provide any example here: please use --helpxml instead. - -This module requires at least python 2.2.1 to run. -""" - -import cgi -import getopt -import os -import re -import string -import struct -import sys -# pylint: disable-msg=C6204 -try: - import fcntl -except ImportError: - fcntl = None -try: - # Importing termios will fail on non-unix platforms. - import termios -except ImportError: - termios = None - -import gflags_validators -# pylint: enable-msg=C6204 - - -# Are we running under pychecker? 
-_RUNNING_PYCHECKER = 'pychecker.python' in sys.modules - - -def _GetCallingModuleObjectAndName(): - """Returns the module that's calling into this module. - - We generally use this function to get the name of the module calling a - DEFINE_foo... function. - """ - # Walk down the stack to find the first globals dict that's not ours. - for depth in range(1, sys.getrecursionlimit()): - if not sys._getframe(depth).f_globals is globals(): - globals_for_frame = sys._getframe(depth).f_globals - module, module_name = _GetModuleObjectAndName(globals_for_frame) - if module_name is not None: - return module, module_name - raise AssertionError("No module was found") - - -def _GetCallingModule(): - """Returns the name of the module that's calling into this module.""" - return _GetCallingModuleObjectAndName()[1] - - -def _GetThisModuleObjectAndName(): - """Returns: (module object, module name) for this module.""" - return _GetModuleObjectAndName(globals()) - - -# module exceptions: -class FlagsError(Exception): - """The base class for all flags errors.""" - pass - - -class DuplicateFlag(FlagsError): - """Raised if there is a flag naming conflict.""" - pass - -class CantOpenFlagFileError(FlagsError): - """Raised if flagfile fails to open: doesn't exist, wrong permissions, etc.""" - pass - - -class DuplicateFlagCannotPropagateNoneToSwig(DuplicateFlag): - """Special case of DuplicateFlag -- SWIG flag value can't be set to None. - - This can be raised when a duplicate flag is created. Even if allow_override is - True, we still abort if the new value is None, because it's currently - impossible to pass None default value back to SWIG. See FlagValues.SetDefault - for details. - """ - pass - - -class DuplicateFlagError(DuplicateFlag): - """A DuplicateFlag whose message cites the conflicting definitions. - - A DuplicateFlagError conveys more information than a DuplicateFlag, - namely the modules where the conflicting definitions occur. 
This - class was created to avoid breaking external modules which depend on - the existing DuplicateFlags interface. - """ - - def __init__(self, flagname, flag_values, other_flag_values=None): - """Create a DuplicateFlagError. - - Args: - flagname: Name of the flag being redefined. - flag_values: FlagValues object containing the first definition of - flagname. - other_flag_values: If this argument is not None, it should be the - FlagValues object where the second definition of flagname occurs. - If it is None, we assume that we're being called when attempting - to create the flag a second time, and we use the module calling - this one as the source of the second definition. - """ - self.flagname = flagname - first_module = flag_values.FindModuleDefiningFlag( - flagname, default='') - if other_flag_values is None: - second_module = _GetCallingModule() - else: - second_module = other_flag_values.FindModuleDefiningFlag( - flagname, default='') - msg = "The flag '%s' is defined twice. First from %s, Second from %s" % ( - self.flagname, first_module, second_module) - DuplicateFlag.__init__(self, msg) - - -class IllegalFlagValue(FlagsError): - """The flag command line argument is illegal.""" - pass - - -class UnrecognizedFlag(FlagsError): - """Raised if a flag is unrecognized.""" - pass - - -# An UnrecognizedFlagError conveys more information than an UnrecognizedFlag. -# Since there are external modules that create DuplicateFlags, the interface to -# DuplicateFlag shouldn't change. The flagvalue will be assigned the full value -# of the flag and its argument, if any, allowing handling of unrecognized flags -# in an exception handler. -# If flagvalue is the empty string, then this exception is an due to a -# reference to a flag that was not already defined. 
-class UnrecognizedFlagError(UnrecognizedFlag): - def __init__(self, flagname, flagvalue=''): - self.flagname = flagname - self.flagvalue = flagvalue - UnrecognizedFlag.__init__( - self, "Unknown command line flag '%s'" % flagname) - -# Global variable used by expvar -_exported_flags = {} -_help_width = 80 # width of help output - - -def GetHelpWidth(): - """Returns: an integer, the width of help lines that is used in TextWrap.""" - if (not sys.stdout.isatty()) or (termios is None) or (fcntl is None): - return _help_width - try: - data = fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, '1234') - columns = struct.unpack('hh', data)[1] - # Emacs mode returns 0. - # Here we assume that any value below 40 is unreasonable - if columns >= 40: - return columns - # Returning an int as default is fine, int(int) just return the int. - return int(os.getenv('COLUMNS', _help_width)) - - except (TypeError, IOError, struct.error): - return _help_width - - -def CutCommonSpacePrefix(text): - """Removes a common space prefix from the lines of a multiline text. - - If the first line does not start with a space, it is left as it is and - only in the remaining lines a common space prefix is being searched - for. That means the first line will stay untouched. This is especially - useful to turn doc strings into help texts. This is because some - people prefer to have the doc comment start already after the - apostrophe and then align the following lines while others have the - apostrophes on a separate line. - - The function also drops trailing empty lines and ignores empty lines - following the initial content line while calculating the initial - common whitespace. - - Args: - text: text to work on - - Returns: - the resulting text - """ - text_lines = text.splitlines() - # Drop trailing empty lines - while text_lines and not text_lines[-1]: - text_lines = text_lines[:-1] - if text_lines: - # We got some content, is the first line starting with a space? 
- if text_lines[0] and text_lines[0][0].isspace(): - text_first_line = [] - else: - text_first_line = [text_lines.pop(0)] - # Calculate length of common leading whitespace (only over content lines) - common_prefix = os.path.commonprefix([line for line in text_lines if line]) - space_prefix_len = len(common_prefix) - len(common_prefix.lstrip()) - # If we have a common space prefix, drop it from all lines - if space_prefix_len: - for index in xrange(len(text_lines)): - if text_lines[index]: - text_lines[index] = text_lines[index][space_prefix_len:] - return '\n'.join(text_first_line + text_lines) - return '' - - -def TextWrap(text, length=None, indent='', firstline_indent=None, tabs=' '): - """Wraps a given text to a maximum line length and returns it. - - We turn lines that only contain whitespace into empty lines. We keep - new lines and tabs (e.g., we do not treat tabs as spaces). - - Args: - text: text to wrap - length: maximum length of a line, includes indentation - if this is None then use GetHelpWidth() - indent: indent for all but first line - firstline_indent: indent for first line; if None, fall back to indent - tabs: replacement for tabs - - Returns: - wrapped text - - Raises: - FlagsError: if indent not shorter than length - FlagsError: if firstline_indent not shorter than length - """ - # Get defaults where callee used None - if length is None: - length = GetHelpWidth() - if indent is None: - indent = '' - if len(indent) >= length: - raise FlagsError('Indent must be shorter than length') - # In line we will be holding the current line which is to be started - # with indent (or firstline_indent if available) and then appended - # with words. 
- if firstline_indent is None: - firstline_indent = '' - line = indent - else: - line = firstline_indent - if len(firstline_indent) >= length: - raise FlagsError('First line indent must be shorter than length') - - # If the callee does not care about tabs we simply convert them to - # spaces If callee wanted tabs to be single space then we do that - # already here. - if not tabs or tabs == ' ': - text = text.replace('\t', ' ') - else: - tabs_are_whitespace = not tabs.strip() - - line_regex = re.compile('([ ]*)(\t*)([^ \t]+)', re.MULTILINE) - - # Split the text into lines and the lines with the regex above. The - # resulting lines are collected in result[]. For each split we get the - # spaces, the tabs and the next non white space (e.g. next word). - result = [] - for text_line in text.splitlines(): - # Store result length so we can find out whether processing the next - # line gave any new content - old_result_len = len(result) - # Process next line with line_regex. For optimization we do an rstrip(). - # - process tabs (changes either line or word, see below) - # - process word (first try to squeeze on line, then wrap or force wrap) - # Spaces found on the line are ignored, they get added while wrapping as - # needed. - for spaces, current_tabs, word in line_regex.findall(text_line.rstrip()): - # If tabs weren't converted to spaces, handle them now - if current_tabs: - # If the last thing we added was a space anyway then drop - # it. But let's not get rid of the indentation. 
- if (((result and line != indent) or - (not result and line != firstline_indent)) and line[-1] == ' '): - line = line[:-1] - # Add the tabs, if that means adding whitespace, just add it at - # the line, the rstrip() code while shorten the line down if - # necessary - if tabs_are_whitespace: - line += tabs * len(current_tabs) - else: - # if not all tab replacement is whitespace we prepend it to the word - word = tabs * len(current_tabs) + word - # Handle the case where word cannot be squeezed onto current last line - if len(line) + len(word) > length and len(indent) + len(word) <= length: - result.append(line.rstrip()) - line = indent + word - word = '' - # No space left on line or can we append a space? - if len(line) + 1 >= length: - result.append(line.rstrip()) - line = indent - else: - line += ' ' - # Add word and shorten it up to allowed line length. Restart next - # line with indent and repeat, or add a space if we're done (word - # finished) This deals with words that cannot fit on one line - # (e.g. indent + word longer than allowed line length). - while len(line) + len(word) >= length: - line += word - result.append(line[:length]) - word = line[length:] - line = indent - # Default case, simply append the word and a space - if word: - line += word + ' ' - # End of input line. If we have content we finish the line. If the - # current line is just the indent but we had content in during this - # original line then we need to add an empty line. - if (result and line != indent) or (not result and line != firstline_indent): - result.append(line.rstrip()) - elif len(result) == old_result_len: - result.append('') - line = indent - - return '\n'.join(result) - - -def DocToHelp(doc): - """Takes a __doc__ string and reformats it as help.""" - - # Get rid of starting and ending white space. Using lstrip() or even - # strip() could drop more than maximum of first line and right space - # of last line. 
- doc = doc.strip() - - # Get rid of all empty lines - whitespace_only_line = re.compile('^[ \t]+$', re.M) - doc = whitespace_only_line.sub('', doc) - - # Cut out common space at line beginnings - doc = CutCommonSpacePrefix(doc) - - # Just like this module's comment, comments tend to be aligned somehow. - # In other words they all start with the same amount of white space - # 1) keep double new lines - # 2) keep ws after new lines if not empty line - # 3) all other new lines shall be changed to a space - # Solution: Match new lines between non white space and replace with space. - doc = re.sub('(?<=\S)\n(?=\S)', ' ', doc, re.M) - - return doc - - -def _GetModuleObjectAndName(globals_dict): - """Returns the module that defines a global environment, and its name. - - Args: - globals_dict: A dictionary that should correspond to an environment - providing the values of the globals. - - Returns: - A pair consisting of (1) module object and (2) module name (a - string). Returns (None, None) if the module could not be - identified. - """ - # The use of .items() (instead of .iteritems()) is NOT a mistake: if - # a parallel thread imports a module while we iterate over - # .iteritems() (not nice, but possible), we get a RuntimeError ... - # Hence, we use the slightly slower but safer .items(). - for name, module in sys.modules.items(): - if getattr(module, '__dict__', None) is globals_dict: - if name == '__main__': - # Pick a more informative name for the main module. - name = sys.argv[0] - return (module, name) - return (None, None) - - -def _GetMainModule(): - """Returns: string, name of the module from which execution started.""" - # First, try to use the same logic used by _GetCallingModuleObjectAndName(), - # i.e., call _GetModuleObjectAndName(). For that we first need to - # find the dictionary that the main module uses to store the - # globals. - # - # That's (normally) the same dictionary object that the deepest - # (oldest) stack frame is using for globals. 
- deepest_frame = sys._getframe(0) - while deepest_frame.f_back is not None: - deepest_frame = deepest_frame.f_back - globals_for_main_module = deepest_frame.f_globals - main_module_name = _GetModuleObjectAndName(globals_for_main_module)[1] - # The above strategy fails in some cases (e.g., tools that compute - # code coverage by redefining, among other things, the main module). - # If so, just use sys.argv[0]. We can probably always do this, but - # it's safest to try to use the same logic as _GetCallingModuleObjectAndName() - if main_module_name is None: - main_module_name = sys.argv[0] - return main_module_name - - -class FlagValues: - """Registry of 'Flag' objects. - - A 'FlagValues' can then scan command line arguments, passing flag - arguments through to the 'Flag' objects that it owns. It also - provides easy access to the flag values. Typically only one - 'FlagValues' object is needed by an application: gflags.FLAGS - - This class is heavily overloaded: - - 'Flag' objects are registered via __setitem__: - FLAGS['longname'] = x # register a new flag - - The .value attribute of the registered 'Flag' objects can be accessed - as attributes of this 'FlagValues' object, through __getattr__. Both - the long and short name of the original 'Flag' objects can be used to - access its value: - FLAGS.longname # parsed flag value - FLAGS.x # parsed flag value (short name) - - Command line arguments are scanned and passed to the registered 'Flag' - objects through the __call__ method. Unparsed arguments, including - argv[0] (e.g. the program name) are returned. - argv = FLAGS(sys.argv) # scan command line arguments - - The original registered Flag objects can be retrieved through the use - of the dictionary-like operator, __getitem__: - x = FLAGS['longname'] # access the registered Flag object - - The str() operator of a 'FlagValues' object provides help for all of - the registered 'Flag' objects. 
- """ - - def __init__(self): - # Since everything in this class is so heavily overloaded, the only - # way of defining and using fields is to access __dict__ directly. - - # Dictionary: flag name (string) -> Flag object. - self.__dict__['__flags'] = {} - # Dictionary: module name (string) -> list of Flag objects that are defined - # by that module. - self.__dict__['__flags_by_module'] = {} - # Dictionary: module id (int) -> list of Flag objects that are defined by - # that module. - self.__dict__['__flags_by_module_id'] = {} - # Dictionary: module name (string) -> list of Flag objects that are - # key for that module. - self.__dict__['__key_flags_by_module'] = {} - - # Set if we should use new style gnu_getopt rather than getopt when parsing - # the args. Only possible with Python 2.3+ - self.UseGnuGetOpt(False) - - def UseGnuGetOpt(self, use_gnu_getopt=True): - """Use GNU-style scanning. Allows mixing of flag and non-flag arguments. - - See http://docs.python.org/library/getopt.html#getopt.gnu_getopt - - Args: - use_gnu_getopt: wether or not to use GNU style scanning. - """ - self.__dict__['__use_gnu_getopt'] = use_gnu_getopt - - def IsGnuGetOpt(self): - return self.__dict__['__use_gnu_getopt'] - - def FlagDict(self): - return self.__dict__['__flags'] - - def FlagsByModuleDict(self): - """Returns the dictionary of module_name -> list of defined flags. - - Returns: - A dictionary. Its keys are module names (strings). Its values - are lists of Flag objects. - """ - return self.__dict__['__flags_by_module'] - - def FlagsByModuleIdDict(self): - """Returns the dictionary of module_id -> list of defined flags. - - Returns: - A dictionary. Its keys are module IDs (ints). Its values - are lists of Flag objects. - """ - return self.__dict__['__flags_by_module_id'] - - def KeyFlagsByModuleDict(self): - """Returns the dictionary of module_name -> list of key flags. - - Returns: - A dictionary. Its keys are module names (strings). Its values - are lists of Flag objects. 
- """ - return self.__dict__['__key_flags_by_module'] - - def _RegisterFlagByModule(self, module_name, flag): - """Records the module that defines a specific flag. - - We keep track of which flag is defined by which module so that we - can later sort the flags by module. - - Args: - module_name: A string, the name of a Python module. - flag: A Flag object, a flag that is key to the module. - """ - flags_by_module = self.FlagsByModuleDict() - flags_by_module.setdefault(module_name, []).append(flag) - - def _RegisterFlagByModuleId(self, module_id, flag): - """Records the module that defines a specific flag. - - Args: - module_id: An int, the ID of the Python module. - flag: A Flag object, a flag that is key to the module. - """ - flags_by_module_id = self.FlagsByModuleIdDict() - flags_by_module_id.setdefault(module_id, []).append(flag) - - def _RegisterKeyFlagForModule(self, module_name, flag): - """Specifies that a flag is a key flag for a module. - - Args: - module_name: A string, the name of a Python module. - flag: A Flag object, a flag that is key to the module. - """ - key_flags_by_module = self.KeyFlagsByModuleDict() - # The list of key flags for the module named module_name. - key_flags = key_flags_by_module.setdefault(module_name, []) - # Add flag, but avoid duplicates. - if flag not in key_flags: - key_flags.append(flag) - - def _GetFlagsDefinedByModule(self, module): - """Returns the list of flags defined by a module. - - Args: - module: A module object or a module name (a string). - - Returns: - A new list of Flag objects. Caller may update this list as he - wishes: none of those changes will affect the internals of this - FlagValue object. - """ - if not isinstance(module, str): - module = module.__name__ - - return list(self.FlagsByModuleDict().get(module, [])) - - def _GetKeyFlagsForModule(self, module): - """Returns the list of key flags for a module. 
- - Args: - module: A module object or a module name (a string) - - Returns: - A new list of Flag objects. Caller may update this list as he - wishes: none of those changes will affect the internals of this - FlagValue object. - """ - if not isinstance(module, str): - module = module.__name__ - - # Any flag is a key flag for the module that defined it. NOTE: - # key_flags is a fresh list: we can update it without affecting the - # internals of this FlagValues object. - key_flags = self._GetFlagsDefinedByModule(module) - - # Take into account flags explicitly declared as key for a module. - for flag in self.KeyFlagsByModuleDict().get(module, []): - if flag not in key_flags: - key_flags.append(flag) - return key_flags - - def FindModuleDefiningFlag(self, flagname, default=None): - """Return the name of the module defining this flag, or default. - - Args: - flagname: Name of the flag to lookup. - default: Value to return if flagname is not defined. Defaults - to None. - - Returns: - The name of the module which registered the flag with this name. - If no such module exists (i.e. no flag with this name exists), - we return default. - """ - for module, flags in self.FlagsByModuleDict().iteritems(): - for flag in flags: - if flag.name == flagname or flag.short_name == flagname: - return module - return default - - def FindModuleIdDefiningFlag(self, flagname, default=None): - """Return the ID of the module defining this flag, or default. - - Args: - flagname: Name of the flag to lookup. - default: Value to return if flagname is not defined. Defaults - to None. - - Returns: - The ID of the module which registered the flag with this name. - If no such module exists (i.e. no flag with this name exists), - we return default. 
- """ - for module_id, flags in self.FlagsByModuleIdDict().iteritems(): - for flag in flags: - if flag.name == flagname or flag.short_name == flagname: - return module_id - return default - - def AppendFlagValues(self, flag_values): - """Appends flags registered in another FlagValues instance. - - Args: - flag_values: registry to copy from - """ - for flag_name, flag in flag_values.FlagDict().iteritems(): - # Each flags with shortname appears here twice (once under its - # normal name, and again with its short name). To prevent - # problems (DuplicateFlagError) with double flag registration, we - # perform a check to make sure that the entry we're looking at is - # for its normal name. - if flag_name == flag.name: - try: - self[flag_name] = flag - except DuplicateFlagError: - raise DuplicateFlagError(flag_name, self, - other_flag_values=flag_values) - - def RemoveFlagValues(self, flag_values): - """Remove flags that were previously appended from another FlagValues. - - Args: - flag_values: registry containing flags to remove. - """ - for flag_name in flag_values.FlagDict(): - self.__delattr__(flag_name) - - def __setitem__(self, name, flag): - """Registers a new flag variable.""" - fl = self.FlagDict() - if not isinstance(flag, Flag): - raise IllegalFlagValue(flag) - if not isinstance(name, type("")): - raise FlagsError("Flag name must be a string") - if len(name) == 0: - raise FlagsError("Flag name cannot be empty") - # If running under pychecker, duplicate keys are likely to be - # defined. Disable check for duplicate keys when pycheck'ing. 
- if (name in fl and not flag.allow_override and - not fl[name].allow_override and not _RUNNING_PYCHECKER): - module, module_name = _GetCallingModuleObjectAndName() - if (self.FindModuleDefiningFlag(name) == module_name and - id(module) != self.FindModuleIdDefiningFlag(name)): - # If the flag has already been defined by a module with the same name, - # but a different ID, we can stop here because it indicates that the - # module is simply being imported a subsequent time. - return - raise DuplicateFlagError(name, self) - short_name = flag.short_name - if short_name is not None: - if (short_name in fl and not flag.allow_override and - not fl[short_name].allow_override and not _RUNNING_PYCHECKER): - raise DuplicateFlagError(short_name, self) - fl[short_name] = flag - fl[name] = flag - global _exported_flags - _exported_flags[name] = flag - - def __getitem__(self, name): - """Retrieves the Flag object for the flag --name.""" - return self.FlagDict()[name] - - def __getattr__(self, name): - """Retrieves the 'value' attribute of the flag --name.""" - fl = self.FlagDict() - if name not in fl: - raise AttributeError(name) - return fl[name].value - - def __setattr__(self, name, value): - """Sets the 'value' attribute of the flag --name.""" - fl = self.FlagDict() - fl[name].value = value - self._AssertValidators(fl[name].validators) - return value - - def _AssertAllValidators(self): - all_validators = set() - for flag in self.FlagDict().itervalues(): - for validator in flag.validators: - all_validators.add(validator) - self._AssertValidators(all_validators) - - def _AssertValidators(self, validators): - """Assert if all validators in the list are satisfied. - - Asserts validators in the order they were created. - Args: - validators: Iterable(gflags_validators.Validator), validators to be - verified - Raises: - AttributeError: if validators work with a non-existing flag. 
- IllegalFlagValue: if validation fails for at least one validator - """ - for validator in sorted( - validators, key=lambda validator: validator.insertion_index): - try: - validator.Verify(self) - except gflags_validators.Error, e: - message = validator.PrintFlagsWithValues(self) - raise IllegalFlagValue('%s: %s' % (message, str(e))) - - def _FlagIsRegistered(self, flag_obj): - """Checks whether a Flag object is registered under some name. - - Note: this is non trivial: in addition to its normal name, a flag - may have a short name too. In self.FlagDict(), both the normal and - the short name are mapped to the same flag object. E.g., calling - only "del FLAGS.short_name" is not unregistering the corresponding - Flag object (it is still registered under the longer name). - - Args: - flag_obj: A Flag object. - - Returns: - A boolean: True iff flag_obj is registered under some name. - """ - flag_dict = self.FlagDict() - # Check whether flag_obj is registered under its long name. - name = flag_obj.name - if flag_dict.get(name, None) == flag_obj: - return True - # Check whether flag_obj is registered under its short name. - short_name = flag_obj.short_name - if (short_name is not None and - flag_dict.get(short_name, None) == flag_obj): - return True - # The flag cannot be registered under any other name, so we do not - # need to do a full search through the values of self.FlagDict(). - return False - - def __delattr__(self, flag_name): - """Deletes a previously-defined flag from a flag object. - - This method makes sure we can delete a flag by using - - del flag_values_object. - - E.g., - - gflags.DEFINE_integer('foo', 1, 'Integer flag.') - del gflags.FLAGS.foo - - Args: - flag_name: A string, the name of the flag to be deleted. - - Raises: - AttributeError: When there is no registered flag named flag_name. 
- """ - fl = self.FlagDict() - if flag_name not in fl: - raise AttributeError(flag_name) - - flag_obj = fl[flag_name] - del fl[flag_name] - - if not self._FlagIsRegistered(flag_obj): - # If the Flag object indicated by flag_name is no longer - # registered (please see the docstring of _FlagIsRegistered), then - # we delete the occurrences of the flag object in all our internal - # dictionaries. - self.__RemoveFlagFromDictByModule(self.FlagsByModuleDict(), flag_obj) - self.__RemoveFlagFromDictByModule(self.FlagsByModuleIdDict(), flag_obj) - self.__RemoveFlagFromDictByModule(self.KeyFlagsByModuleDict(), flag_obj) - - def __RemoveFlagFromDictByModule(self, flags_by_module_dict, flag_obj): - """Removes a flag object from a module -> list of flags dictionary. - - Args: - flags_by_module_dict: A dictionary that maps module names to lists of - flags. - flag_obj: A flag object. - """ - for unused_module, flags_in_module in flags_by_module_dict.iteritems(): - # while (as opposed to if) takes care of multiple occurrences of a - # flag in the list for the same module. - while flag_obj in flags_in_module: - flags_in_module.remove(flag_obj) - - def SetDefault(self, name, value): - """Changes the default value of the named flag object.""" - fl = self.FlagDict() - if name not in fl: - raise AttributeError(name) - fl[name].SetDefault(value) - self._AssertValidators(fl[name].validators) - - def __contains__(self, name): - """Returns True if name is a value (flag) in the dict.""" - return name in self.FlagDict() - - has_key = __contains__ # a synonym for __contains__() - - def __iter__(self): - return iter(self.FlagDict()) - - def __call__(self, argv): - """Parses flags from argv; stores parsed flags into this FlagValues object. - - All unparsed arguments are returned. Flags are parsed using the GNU - Program Argument Syntax Conventions, using getopt: - - http://www.gnu.org/software/libc/manual/html_mono/libc.html#Getopt - - Args: - argv: argument list. 
Can be of any type that may be converted to a list. - - Returns: - The list of arguments not parsed as options, including argv[0] - - Raises: - FlagsError: on any parsing error - """ - # Support any sequence type that can be converted to a list - argv = list(argv) - - shortopts = "" - longopts = [] - - fl = self.FlagDict() - - # This pre parses the argv list for --flagfile=<> options. - argv = argv[:1] + self.ReadFlagsFromFiles(argv[1:], force_gnu=False) - - # Correct the argv to support the google style of passing boolean - # parameters. Boolean parameters may be passed by using --mybool, - # --nomybool, --mybool=(true|false|1|0). getopt does not support - # having options that may or may not have a parameter. We replace - # instances of the short form --mybool and --nomybool with their - # full forms: --mybool=(true|false). - original_argv = list(argv) # list() makes a copy - shortest_matches = None - for name, flag in fl.items(): - if not flag.boolean: - continue - if shortest_matches is None: - # Determine the smallest allowable prefix for all flag names - shortest_matches = self.ShortestUniquePrefixes(fl) - no_name = 'no' + name - prefix = shortest_matches[name] - no_prefix = shortest_matches[no_name] - - # Replace all occurrences of this boolean with extended forms - for arg_idx in range(1, len(argv)): - arg = argv[arg_idx] - if arg.find('=') >= 0: continue - if arg.startswith('--'+prefix) and ('--'+name).startswith(arg): - argv[arg_idx] = ('--%s=true' % name) - elif arg.startswith('--'+no_prefix) and ('--'+no_name).startswith(arg): - argv[arg_idx] = ('--%s=false' % name) - - # Loop over all of the flags, building up the lists of short options - # and long options that will be passed to getopt. Short options are - # specified as a string of letters, each letter followed by a colon - # if it takes an argument. Long options are stored in an array of - # strings. Each string ends with an '=' if it takes an argument. 
- for name, flag in fl.items(): - longopts.append(name + "=") - if len(name) == 1: # one-letter option: allow short flag type also - shortopts += name - if not flag.boolean: - shortopts += ":" - - longopts.append('undefok=') - undefok_flags = [] - - # In case --undefok is specified, loop to pick up unrecognized - # options one by one. - unrecognized_opts = [] - args = argv[1:] - while True: - try: - if self.__dict__['__use_gnu_getopt']: - optlist, unparsed_args = getopt.gnu_getopt(args, shortopts, longopts) - else: - optlist, unparsed_args = getopt.getopt(args, shortopts, longopts) - break - except getopt.GetoptError, e: - if not e.opt or e.opt in fl: - # Not an unrecognized option, re-raise the exception as a FlagsError - raise FlagsError(e) - # Remove offender from args and try again - for arg_index in range(len(args)): - if ((args[arg_index] == '--' + e.opt) or - (args[arg_index] == '-' + e.opt) or - (args[arg_index].startswith('--' + e.opt + '='))): - unrecognized_opts.append((e.opt, args[arg_index])) - args = args[0:arg_index] + args[arg_index+1:] - break - else: - # We should have found the option, so we don't expect to get - # here. We could assert, but raising the original exception - # might work better. - raise FlagsError(e) - - for name, arg in optlist: - if name == '--undefok': - flag_names = arg.split(',') - undefok_flags.extend(flag_names) - # For boolean flags, if --undefok=boolflag is specified, then we should - # also accept --noboolflag, in addition to --boolflag. - # Since we don't know the type of the undefok'd flag, this will affect - # non-boolean flags as well. - # NOTE: You shouldn't use --undefok=noboolflag, because then we will - # accept --nonoboolflag here. We are choosing not to do the conversion - # from noboolflag -> boolflag because of the ambiguity that flag names - # can start with 'no'. 
- undefok_flags.extend('no' + name for name in flag_names) - continue - if name.startswith('--'): - # long option - name = name[2:] - short_option = 0 - else: - # short option - name = name[1:] - short_option = 1 - if name in fl: - flag = fl[name] - if flag.boolean and short_option: arg = 1 - flag.Parse(arg) - - # If there were unrecognized options, raise an exception unless - # the options were named via --undefok. - for opt, value in unrecognized_opts: - if opt not in undefok_flags: - raise UnrecognizedFlagError(opt, value) - - if unparsed_args: - if self.__dict__['__use_gnu_getopt']: - # if using gnu_getopt just return the program name + remainder of argv. - ret_val = argv[:1] + unparsed_args - else: - # unparsed_args becomes the first non-flag detected by getopt to - # the end of argv. Because argv may have been modified above, - # return original_argv for this region. - ret_val = argv[:1] + original_argv[-len(unparsed_args):] - else: - ret_val = argv[:1] - - self._AssertAllValidators() - return ret_val - - def Reset(self): - """Resets the values to the point before FLAGS(argv) was called.""" - for f in self.FlagDict().values(): - f.Unparse() - - def RegisteredFlags(self): - """Returns: a list of the names and short names of all registered flags.""" - return list(self.FlagDict()) - - def FlagValuesDict(self): - """Returns: a dictionary that maps flag names to flag values.""" - flag_values = {} - - for flag_name in self.RegisteredFlags(): - flag = self.FlagDict()[flag_name] - flag_values[flag_name] = flag.value - - return flag_values - - def __str__(self): - """Generates a help string for all known flags.""" - return self.GetHelp() - - def GetHelp(self, prefix=''): - """Generates a help string for all known flags.""" - helplist = [] - - flags_by_module = self.FlagsByModuleDict() - if flags_by_module: - - modules = sorted(flags_by_module) - - # Print the help for the main module first, if possible. 
- main_module = _GetMainModule() - if main_module in modules: - modules.remove(main_module) - modules = [main_module] + modules - - for module in modules: - self.__RenderOurModuleFlags(module, helplist) - - self.__RenderModuleFlags('gflags', - _SPECIAL_FLAGS.FlagDict().values(), - helplist) - - else: - # Just print one long list of flags. - self.__RenderFlagList( - self.FlagDict().values() + _SPECIAL_FLAGS.FlagDict().values(), - helplist, prefix) - - return '\n'.join(helplist) - - def __RenderModuleFlags(self, module, flags, output_lines, prefix=""): - """Generates a help string for a given module.""" - if not isinstance(module, str): - module = module.__name__ - output_lines.append('\n%s%s:' % (prefix, module)) - self.__RenderFlagList(flags, output_lines, prefix + " ") - - def __RenderOurModuleFlags(self, module, output_lines, prefix=""): - """Generates a help string for a given module.""" - flags = self._GetFlagsDefinedByModule(module) - if flags: - self.__RenderModuleFlags(module, flags, output_lines, prefix) - - def __RenderOurModuleKeyFlags(self, module, output_lines, prefix=""): - """Generates a help string for the key flags of a given module. - - Args: - module: A module object or a module name (a string). - output_lines: A list of strings. The generated help message - lines will be appended to this list. - prefix: A string that is prepended to each generated help line. - """ - key_flags = self._GetKeyFlagsForModule(module) - if key_flags: - self.__RenderModuleFlags(module, key_flags, output_lines, prefix) - - def ModuleHelp(self, module): - """Describe the key flags of a module. - - Args: - module: A module object or a module name (a string). - - Returns: - string describing the key flags of a module. - """ - helplist = [] - self.__RenderOurModuleKeyFlags(module, helplist) - return '\n'.join(helplist) - - def MainModuleHelp(self): - """Describe the key flags of the main module. - - Returns: - string describing the key flags of a module. 
- """ - return self.ModuleHelp(_GetMainModule()) - - def __RenderFlagList(self, flaglist, output_lines, prefix=" "): - fl = self.FlagDict() - special_fl = _SPECIAL_FLAGS.FlagDict() - flaglist = [(flag.name, flag) for flag in flaglist] - flaglist.sort() - flagset = {} - for (name, flag) in flaglist: - # It's possible this flag got deleted or overridden since being - # registered in the per-module flaglist. Check now against the - # canonical source of current flag information, the FlagDict. - if fl.get(name, None) != flag and special_fl.get(name, None) != flag: - # a different flag is using this name now - continue - # only print help once - if flag in flagset: continue - flagset[flag] = 1 - flaghelp = "" - if flag.short_name: flaghelp += "-%s," % flag.short_name - if flag.boolean: - flaghelp += "--[no]%s" % flag.name + ":" - else: - flaghelp += "--%s" % flag.name + ":" - flaghelp += " " - if flag.help: - flaghelp += flag.help - flaghelp = TextWrap(flaghelp, indent=prefix+" ", - firstline_indent=prefix) - if flag.default_as_str: - flaghelp += "\n" - flaghelp += TextWrap("(default: %s)" % flag.default_as_str, - indent=prefix+" ") - if flag.parser.syntactic_help: - flaghelp += "\n" - flaghelp += TextWrap("(%s)" % flag.parser.syntactic_help, - indent=prefix+" ") - output_lines.append(flaghelp) - - def get(self, name, default): - """Returns the value of a flag (if not None) or a default value. - - Args: - name: A string, the name of a flag. - default: Default value to use if the flag value is None. 
- """ - - value = self.__getattr__(name) - if value is not None: # Can't do if not value, b/c value might be '0' or "" - return value - else: - return default - - def ShortestUniquePrefixes(self, fl): - """Returns: dictionary; maps flag names to their shortest unique prefix.""" - # Sort the list of flag names - sorted_flags = [] - for name, flag in fl.items(): - sorted_flags.append(name) - if flag.boolean: - sorted_flags.append('no%s' % name) - sorted_flags.sort() - - # For each name in the sorted list, determine the shortest unique - # prefix by comparing itself to the next name and to the previous - # name (the latter check uses cached info from the previous loop). - shortest_matches = {} - prev_idx = 0 - for flag_idx in range(len(sorted_flags)): - curr = sorted_flags[flag_idx] - if flag_idx == (len(sorted_flags) - 1): - next = None - else: - next = sorted_flags[flag_idx+1] - next_len = len(next) - for curr_idx in range(len(curr)): - if (next is None - or curr_idx >= next_len - or curr[curr_idx] != next[curr_idx]): - # curr longer than next or no more chars in common - shortest_matches[curr] = curr[:max(prev_idx, curr_idx) + 1] - prev_idx = curr_idx - break - else: - # curr shorter than (or equal to) next - shortest_matches[curr] = curr - prev_idx = curr_idx + 1 # next will need at least one more char - return shortest_matches - - def __IsFlagFileDirective(self, flag_string): - """Checks whether flag_string contain a --flagfile= directive.""" - if isinstance(flag_string, type("")): - if flag_string.startswith('--flagfile='): - return 1 - elif flag_string == '--flagfile': - return 1 - elif flag_string.startswith('-flagfile='): - return 1 - elif flag_string == '-flagfile': - return 1 - else: - return 0 - return 0 - - def ExtractFilename(self, flagfile_str): - """Returns filename from a flagfile_str of form -[-]flagfile=filename. 
- - The cases of --flagfile foo and -flagfile foo shouldn't be hitting - this function, as they are dealt with in the level above this - function. - """ - if flagfile_str.startswith('--flagfile='): - return os.path.expanduser((flagfile_str[(len('--flagfile=')):]).strip()) - elif flagfile_str.startswith('-flagfile='): - return os.path.expanduser((flagfile_str[(len('-flagfile=')):]).strip()) - else: - raise FlagsError('Hit illegal --flagfile type: %s' % flagfile_str) - - def __GetFlagFileLines(self, filename, parsed_file_list): - """Returns the useful (!=comments, etc) lines from a file with flags. - - Args: - filename: A string, the name of the flag file. - parsed_file_list: A list of the names of the files we have - already read. MUTATED BY THIS FUNCTION. - - Returns: - List of strings. See the note below. - - NOTE(springer): This function checks for a nested --flagfile= - tag and handles the lower file recursively. It returns a list of - all the lines that _could_ contain command flags. This is - EVERYTHING except whitespace lines and comments (lines starting - with '#' or '//'). - """ - line_list = [] # All line from flagfile. - flag_line_list = [] # Subset of lines w/o comments, blanks, flagfile= tags. - try: - file_obj = open(filename, 'r') - except IOError, e_msg: - raise CantOpenFlagFileError('ERROR:: Unable to open flagfile: %s' % e_msg) - - line_list = file_obj.readlines() - file_obj.close() - parsed_file_list.append(filename) - - # This is where we check each line in the file we just read. - for line in line_list: - if line.isspace(): - pass - # Checks for comment (a line that starts with '#'). - elif line.startswith('#') or line.startswith('//'): - pass - # Checks for a nested "--flagfile=" flag in the current file. - # If we find one, recursively parse down into that file. - elif self.__IsFlagFileDirective(line): - sub_filename = self.ExtractFilename(line) - # We do a little safety check for reparsing a file we've already done. 
- if not sub_filename in parsed_file_list: - included_flags = self.__GetFlagFileLines(sub_filename, - parsed_file_list) - flag_line_list.extend(included_flags) - else: # Case of hitting a circularly included file. - sys.stderr.write('Warning: Hit circular flagfile dependency: %s\n' % - (sub_filename,)) - else: - # Any line that's not a comment or a nested flagfile should get - # copied into 2nd position. This leaves earlier arguments - # further back in the list, thus giving them higher priority. - flag_line_list.append(line.strip()) - return flag_line_list - - def ReadFlagsFromFiles(self, argv, force_gnu=True): - """Processes command line args, but also allow args to be read from file. - - Args: - argv: A list of strings, usually sys.argv[1:], which may contain one or - more flagfile directives of the form --flagfile="./filename". - Note that the name of the program (sys.argv[0]) should be omitted. - force_gnu: If False, --flagfile parsing obeys normal flag semantics. - If True, --flagfile parsing instead follows gnu_getopt semantics. - *** WARNING *** force_gnu=False may become the future default! - - Returns: - - A new list which has the original list combined with what we read - from any flagfile(s). - - References: Global gflags.FLAG class instance. - - This function should be called before the normal FLAGS(argv) call. - This function scans the input list for a flag that looks like: - --flagfile=. Then it opens , reads all valid key - and value pairs and inserts them into the input list between the - first item of the list and any subsequent items in the list. - - Note that your application's flags are still defined the usual way - using gflags DEFINE_flag() type functions. - - Notes (assuming we're getting a commandline of some sort as our input): - --> Flags from the command line argv _should_ always take precedence! - --> A further "--flagfile=" CAN be nested in a flagfile. - It will be processed after the parent flag file is done. 
- --> For duplicate flags, first one we hit should "win". - --> In a flagfile, a line beginning with # or // is a comment. - --> Entirely blank lines _should_ be ignored. - """ - parsed_file_list = [] - rest_of_args = argv - new_argv = [] - while rest_of_args: - current_arg = rest_of_args[0] - rest_of_args = rest_of_args[1:] - if self.__IsFlagFileDirective(current_arg): - # This handles the case of -(-)flagfile foo. In this case the - # next arg really is part of this one. - if current_arg == '--flagfile' or current_arg == '-flagfile': - if not rest_of_args: - raise IllegalFlagValue('--flagfile with no argument') - flag_filename = os.path.expanduser(rest_of_args[0]) - rest_of_args = rest_of_args[1:] - else: - # This handles the case of (-)-flagfile=foo. - flag_filename = self.ExtractFilename(current_arg) - new_argv.extend( - self.__GetFlagFileLines(flag_filename, parsed_file_list)) - else: - new_argv.append(current_arg) - # Stop parsing after '--', like getopt and gnu_getopt. - if current_arg == '--': - break - # Stop parsing after a non-flag, like getopt. - if not current_arg.startswith('-'): - if not force_gnu and not self.__dict__['__use_gnu_getopt']: - break - - if rest_of_args: - new_argv.extend(rest_of_args) - - return new_argv - - def FlagsIntoString(self): - """Returns a string with the flags assignments from this FlagValues object. - - This function ignores flags whose value is None. Each flag - assignment is separated by a newline. - - NOTE: MUST mirror the behavior of the C++ CommandlineFlagsIntoString - from http://code.google.com/p/google-gflags - """ - s = '' - for flag in self.FlagDict().values(): - if flag.value is not None: - s += flag.Serialize() + '\n' - return s - - def AppendFlagsIntoFile(self, filename): - """Appends all flags assignments from this FlagInfo object to a file. - - Output will be in the format of a flagfile. 
- - NOTE: MUST mirror the behavior of the C++ AppendFlagsIntoFile - from http://code.google.com/p/google-gflags - """ - out_file = open(filename, 'a') - out_file.write(self.FlagsIntoString()) - out_file.close() - - def WriteHelpInXMLFormat(self, outfile=None): - """Outputs flag documentation in XML format. - - NOTE: We use element names that are consistent with those used by - the C++ command-line flag library, from - http://code.google.com/p/google-gflags - We also use a few new elements (e.g., ), but we do not - interfere / overlap with existing XML elements used by the C++ - library. Please maintain this consistency. - - Args: - outfile: File object we write to. Default None means sys.stdout. - """ - outfile = outfile or sys.stdout - - outfile.write('\n') - outfile.write('\n') - indent = ' ' - _WriteSimpleXMLElement(outfile, 'program', os.path.basename(sys.argv[0]), - indent) - - usage_doc = sys.modules['__main__'].__doc__ - if not usage_doc: - usage_doc = '\nUSAGE: %s [flags]\n' % sys.argv[0] - else: - usage_doc = usage_doc.replace('%s', sys.argv[0]) - _WriteSimpleXMLElement(outfile, 'usage', usage_doc, indent) - - # Get list of key flags for the main module. - key_flags = self._GetKeyFlagsForModule(_GetMainModule()) - - # Sort flags by declaring module name and next by flag name. - flags_by_module = self.FlagsByModuleDict() - all_module_names = list(flags_by_module.keys()) - all_module_names.sort() - for module_name in all_module_names: - flag_list = [(f.name, f) for f in flags_by_module[module_name]] - flag_list.sort() - for unused_flag_name, flag in flag_list: - is_key = flag in key_flags - flag.WriteInfoInXMLFormat(outfile, module_name, - is_key=is_key, indent=indent) - - outfile.write('\n') - outfile.flush() - - def AddValidator(self, validator): - """Register new flags validator to be checked. - - Args: - validator: gflags_validators.Validator - Raises: - AttributeError: if validators work with a non-existing flag. 
- """ - for flag_name in validator.GetFlagsNames(): - flag = self.FlagDict()[flag_name] - flag.validators.append(validator) - -# end of FlagValues definition - - -# The global FlagValues instance -FLAGS = FlagValues() - - -def _StrOrUnicode(value): - """Converts value to a python string or, if necessary, unicode-string.""" - try: - return str(value) - except UnicodeEncodeError: - return unicode(value) - - -def _MakeXMLSafe(s): - """Escapes <, >, and & from s, and removes XML 1.0-illegal chars.""" - s = cgi.escape(s) # Escape <, >, and & - # Remove characters that cannot appear in an XML 1.0 document - # (http://www.w3.org/TR/REC-xml/#charsets). - # - # NOTE: if there are problems with current solution, one may move to - # XML 1.1, which allows such chars, if they're entity-escaped (&#xHH;). - s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', '', s) - # Convert non-ascii characters to entities. Note: requires python >=2.3 - s = s.encode('ascii', 'xmlcharrefreplace') # u'\xce\x88' -> 'uΈ' - return s - - -def _WriteSimpleXMLElement(outfile, name, value, indent): - """Writes a simple XML element. - - Args: - outfile: File object we write the XML element to. - name: A string, the name of XML element. - value: A Python object, whose string representation will be used - as the value of the XML element. - indent: A string, prepended to each line of generated output. - """ - value_str = _StrOrUnicode(value) - if isinstance(value, bool): - # Display boolean values as the C++ flag library does: no caps. - value_str = value_str.lower() - safe_value_str = _MakeXMLSafe(value_str) - outfile.write('%s<%s>%s\n' % (indent, name, safe_value_str, name)) - - -class Flag: - """Information about a command-line flag. 
- - 'Flag' objects define the following fields: - .name - the name for this flag - .default - the default value for this flag - .default_as_str - default value as repr'd string, e.g., "'true'" (or None) - .value - the most recent parsed value of this flag; set by Parse() - .help - a help string or None if no help is available - .short_name - the single letter alias for this flag (or None) - .boolean - if 'true', this flag does not accept arguments - .present - true if this flag was parsed from command line flags. - .parser - an ArgumentParser object - .serializer - an ArgumentSerializer object - .allow_override - the flag may be redefined without raising an error - - The only public method of a 'Flag' object is Parse(), but it is - typically only called by a 'FlagValues' object. The Parse() method is - a thin wrapper around the 'ArgumentParser' Parse() method. The parsed - value is saved in .value, and the .present attribute is updated. If - this flag was already present, a FlagsError is raised. - - Parse() is also called during __init__ to parse the default value and - initialize the .value attribute. This enables other python modules to - safely use flags even if the __main__ module neglects to parse the - command line arguments. The .present attribute is cleared after - __init__ parsing. If the default value is set to None, then the - __init__ parsing step is skipped and the .value attribute is - initialized to None. - - Note: The default value is also presented to the user in the help - string, so it is important that it be a legal value for this flag. 
- """ - - def __init__(self, parser, serializer, name, default, help_string, - short_name=None, boolean=0, allow_override=0): - self.name = name - - if not help_string: - help_string = '(no help available)' - - self.help = help_string - self.short_name = short_name - self.boolean = boolean - self.present = 0 - self.parser = parser - self.serializer = serializer - self.allow_override = allow_override - self.value = None - self.validators = [] - - self.SetDefault(default) - - def __hash__(self): - return hash(id(self)) - - def __eq__(self, other): - return self is other - - def __lt__(self, other): - if isinstance(other, Flag): - return id(self) < id(other) - return NotImplemented - - def __GetParsedValueAsString(self, value): - if value is None: - return None - if self.serializer: - return repr(self.serializer.Serialize(value)) - if self.boolean: - if value: - return repr('true') - else: - return repr('false') - return repr(_StrOrUnicode(value)) - - def Parse(self, argument): - try: - self.value = self.parser.Parse(argument) - except ValueError, e: # recast ValueError as IllegalFlagValue - raise IllegalFlagValue("flag --%s=%s: %s" % (self.name, argument, e)) - self.present += 1 - - def Unparse(self): - if self.default is None: - self.value = None - else: - self.Parse(self.default) - self.present = 0 - - def Serialize(self): - if self.value is None: - return '' - if self.boolean: - if self.value: - return "--%s" % self.name - else: - return "--no%s" % self.name - else: - if not self.serializer: - raise FlagsError("Serializer not present for flag %s" % self.name) - return "--%s=%s" % (self.name, self.serializer.Serialize(self.value)) - - def SetDefault(self, value): - """Changes the default value (and current value too) for this Flag.""" - # We can't allow a None override because it may end up not being - # passed to C++ code when we're overriding C++ flags. So we - # cowardly bail out until someone fixes the semantics of trying to - # pass None to a C++ flag. 
See swig_flags.Init() for details on - # this behavior. - # TODO(olexiy): Users can directly call this method, bypassing all flags - # validators (we don't have FlagValues here, so we can not check - # validators). - # The simplest solution I see is to make this method private. - # Another approach would be to store reference to the corresponding - # FlagValues with each flag, but this seems to be an overkill. - if value is None and self.allow_override: - raise DuplicateFlagCannotPropagateNoneToSwig(self.name) - - self.default = value - self.Unparse() - self.default_as_str = self.__GetParsedValueAsString(self.value) - - def Type(self): - """Returns: a string that describes the type of this Flag.""" - # NOTE: we use strings, and not the types.*Type constants because - # our flags can have more exotic types, e.g., 'comma separated list - # of strings', 'whitespace separated list of strings', etc. - return self.parser.Type() - - def WriteInfoInXMLFormat(self, outfile, module_name, is_key=False, indent=''): - """Writes common info about this flag, in XML format. - - This is information that is relevant to all flags (e.g., name, - meaning, etc.). If you defined a flag that has some other pieces of - info, then please override _WriteCustomInfoInXMLFormat. - - Please do NOT override this method. - - Args: - outfile: File object we write to. - module_name: A string, the name of the module that defines this flag. - is_key: A boolean, True iff this flag is key for main module. - indent: A string that is prepended to each generated line. - """ - outfile.write(indent + '\n') - inner_indent = indent + ' ' - if is_key: - _WriteSimpleXMLElement(outfile, 'key', 'yes', inner_indent) - _WriteSimpleXMLElement(outfile, 'file', module_name, inner_indent) - # Print flag features that are relevant for all flags. 
- _WriteSimpleXMLElement(outfile, 'name', self.name, inner_indent) - if self.short_name: - _WriteSimpleXMLElement(outfile, 'short_name', self.short_name, - inner_indent) - if self.help: - _WriteSimpleXMLElement(outfile, 'meaning', self.help, inner_indent) - # The default flag value can either be represented as a string like on the - # command line, or as a Python object. We serialize this value in the - # latter case in order to remain consistent. - if self.serializer and not isinstance(self.default, str): - default_serialized = self.serializer.Serialize(self.default) - else: - default_serialized = self.default - _WriteSimpleXMLElement(outfile, 'default', default_serialized, inner_indent) - _WriteSimpleXMLElement(outfile, 'current', self.value, inner_indent) - _WriteSimpleXMLElement(outfile, 'type', self.Type(), inner_indent) - # Print extra flag features this flag may have. - self._WriteCustomInfoInXMLFormat(outfile, inner_indent) - outfile.write(indent + '\n') - - def _WriteCustomInfoInXMLFormat(self, outfile, indent): - """Writes extra info about this flag, in XML format. - - "Extra" means "not already printed by WriteInfoInXMLFormat above." - - Args: - outfile: File object we write to. - indent: A string that is prepended to each generated line. - """ - # Usually, the parser knows the extra details about the flag, so - # we just forward the call to it. - self.parser.WriteCustomInfoInXMLFormat(outfile, indent) -# End of Flag definition - - -class _ArgumentParserCache(type): - """Metaclass used to cache and share argument parsers among flags.""" - - _instances = {} - - def __call__(mcs, *args, **kwargs): - """Returns an instance of the argument parser cls. - - This method overrides behavior of the __new__ methods in - all subclasses of ArgumentParser (inclusive). If an instance - for mcs with the same set of arguments exists, this instance is - returned, otherwise a new instance is created. 
- - If any keyword arguments are defined, or the values in args - are not hashable, this method always returns a new instance of - cls. - - Args: - args: Positional initializer arguments. - kwargs: Initializer keyword arguments. - - Returns: - An instance of cls, shared or new. - """ - if kwargs: - return type.__call__(mcs, *args, **kwargs) - else: - instances = mcs._instances - key = (mcs,) + tuple(args) - try: - return instances[key] - except KeyError: - # No cache entry for key exists, create a new one. - return instances.setdefault(key, type.__call__(mcs, *args)) - except TypeError: - # An object in args cannot be hashed, always return - # a new instance. - return type.__call__(mcs, *args) - - -class ArgumentParser(object): - """Base class used to parse and convert arguments. - - The Parse() method checks to make sure that the string argument is a - legal value and convert it to a native type. If the value cannot be - converted, it should throw a 'ValueError' exception with a human - readable explanation of why the value is illegal. - - Subclasses should also define a syntactic_help string which may be - presented to the user to describe the form of the legal values. - - Argument parser classes must be stateless, since instances are cached - and shared between flags. Initializer arguments are allowed, but all - member variables must be derived from initializer arguments only. 
- """ - __metaclass__ = _ArgumentParserCache - - syntactic_help = "" - - def Parse(self, argument): - """Default implementation: always returns its argument unmodified.""" - return argument - - def Type(self): - return 'string' - - def WriteCustomInfoInXMLFormat(self, outfile, indent): - pass - - -class ArgumentSerializer: - """Base class for generating string representations of a flag value.""" - - def Serialize(self, value): - return _StrOrUnicode(value) - - -class ListSerializer(ArgumentSerializer): - - def __init__(self, list_sep): - self.list_sep = list_sep - - def Serialize(self, value): - return self.list_sep.join([_StrOrUnicode(x) for x in value]) - - -# Flags validators - - -def RegisterValidator(flag_name, - checker, - message='Flag validation failed', - flag_values=FLAGS): - """Adds a constraint, which will be enforced during program execution. - - The constraint is validated when flags are initially parsed, and after each - change of the corresponding flag's value. - Args: - flag_name: string, name of the flag to be checked. - checker: method to validate the flag. - input - value of the corresponding flag (string, boolean, etc. - This value will be passed to checker by the library). See file's - docstring for examples. - output - Boolean. - Must return True if validator constraint is satisfied. - If constraint is not satisfied, it should either return False or - raise gflags_validators.Error(desired_error_message). - message: error text to be shown to the user if checker returns False. - If checker raises gflags_validators.Error, message from the raised - Error will be shown. - flag_values: FlagValues - Raises: - AttributeError: if flag_name is not registered as a valid flag name. - """ - flag_values.AddValidator(gflags_validators.SimpleValidator(flag_name, - checker, - message)) - - -def MarkFlagAsRequired(flag_name, flag_values=FLAGS): - """Ensure that flag is not None during program execution. 
- - Registers a flag validator, which will follow usual validator - rules. - Args: - flag_name: string, name of the flag - flag_values: FlagValues - Raises: - AttributeError: if flag_name is not registered as a valid flag name. - """ - RegisterValidator(flag_name, - lambda value: value is not None, - message='Flag --%s must be specified.' % flag_name, - flag_values=flag_values) - - -def _RegisterBoundsValidatorIfNeeded(parser, name, flag_values): - """Enforce lower and upper bounds for numeric flags. - - Args: - parser: NumericParser (either FloatParser or IntegerParser). Provides lower - and upper bounds, and help text to display. - name: string, name of the flag - flag_values: FlagValues - """ - if parser.lower_bound is not None or parser.upper_bound is not None: - - def Checker(value): - if value is not None and parser.IsOutsideBounds(value): - message = '%s is not %s' % (value, parser.syntactic_help) - raise gflags_validators.Error(message) - return True - - RegisterValidator(name, - Checker, - flag_values=flag_values) - - -# The DEFINE functions are explained in mode details in the module doc string. - - -def DEFINE(parser, name, default, help, flag_values=FLAGS, serializer=None, - **args): - """Registers a generic Flag object. - - NOTE: in the docstrings of all DEFINE* functions, "registers" is short - for "creates a new flag and registers it". - - Auxiliary function: clients should use the specialized DEFINE_ - function instead. - - Args: - parser: ArgumentParser that is used to parse the flag arguments. - name: A string, the flag name. - default: The default value of the flag. - help: A help string. - flag_values: FlagValues object the flag will be registered with. - serializer: ArgumentSerializer that serializes the flag value. - args: Dictionary with extra keyword args that are passes to the - Flag __init__. 
- """ - DEFINE_flag(Flag(parser, serializer, name, default, help, **args), - flag_values) - - -def DEFINE_flag(flag, flag_values=FLAGS): - """Registers a 'Flag' object with a 'FlagValues' object. - - By default, the global FLAGS 'FlagValue' object is used. - - Typical users will use one of the more specialized DEFINE_xxx - functions, such as DEFINE_string or DEFINE_integer. But developers - who need to create Flag objects themselves should use this function - to register their flags. - """ - # copying the reference to flag_values prevents pychecker warnings - fv = flag_values - fv[flag.name] = flag - # Tell flag_values who's defining the flag. - if isinstance(flag_values, FlagValues): - # Regarding the above isinstance test: some users pass funny - # values of flag_values (e.g., {}) in order to avoid the flag - # registration (in the past, there used to be a flag_values == - # FLAGS test here) and redefine flags with the same name (e.g., - # debug). To avoid breaking their code, we perform the - # registration only if flag_values is a real FlagValues object. - module, module_name = _GetCallingModuleObjectAndName() - flag_values._RegisterFlagByModule(module_name, flag) - flag_values._RegisterFlagByModuleId(id(module), flag) - - -def _InternalDeclareKeyFlags(flag_names, - flag_values=FLAGS, key_flag_values=None): - """Declares a flag as key for the calling module. - - Internal function. User code should call DECLARE_key_flag or - ADOPT_module_key_flags instead. - - Args: - flag_names: A list of strings that are names of already-registered - Flag objects. - flag_values: A FlagValues object that the flags listed in - flag_names have registered with (the value of the flag_values - argument from the DEFINE_* calls that defined those flags). - This should almost never need to be overridden. - key_flag_values: A FlagValues object that (among possibly many - other things) keeps track of the key flags for each module. - Default None means "same as flag_values". 
This should almost - never need to be overridden. - - Raises: - UnrecognizedFlagError: when we refer to a flag that was not - defined yet. - """ - key_flag_values = key_flag_values or flag_values - - module = _GetCallingModule() - - for flag_name in flag_names: - if flag_name not in flag_values: - raise UnrecognizedFlagError(flag_name) - flag = flag_values.FlagDict()[flag_name] - key_flag_values._RegisterKeyFlagForModule(module, flag) - - -def DECLARE_key_flag(flag_name, flag_values=FLAGS): - """Declares one flag as key to the current module. - - Key flags are flags that are deemed really important for a module. - They are important when listing help messages; e.g., if the - --helpshort command-line flag is used, then only the key flags of the - main module are listed (instead of all flags, as in the case of - --help). - - Sample usage: - - gflags.DECLARED_key_flag('flag_1') - - Args: - flag_name: A string, the name of an already declared flag. - (Redeclaring flags as key, including flags implicitly key - because they were declared in this module, is a no-op.) - flag_values: A FlagValues object. This should almost never - need to be overridden. - """ - if flag_name in _SPECIAL_FLAGS: - # Take care of the special flags, e.g., --flagfile, --undefok. - # These flags are defined in _SPECIAL_FLAGS, and are treated - # specially during flag parsing, taking precedence over the - # user-defined flags. - _InternalDeclareKeyFlags([flag_name], - flag_values=_SPECIAL_FLAGS, - key_flag_values=flag_values) - return - _InternalDeclareKeyFlags([flag_name], flag_values=flag_values) - - -def ADOPT_module_key_flags(module, flag_values=FLAGS): - """Declares that all flags key to a module are key to the current module. - - Args: - module: A module object. - flag_values: A FlagValues object. This should almost never need - to be overridden. - - Raises: - FlagsError: When given an argument that is a module name (a - string), instead of a module object. 
- """ - # NOTE(salcianu): an even better test would be if not - # isinstance(module, types.ModuleType) but I didn't want to import - # types for such a tiny use. - if isinstance(module, str): - raise FlagsError('Received module name %s; expected a module object.' - % module) - _InternalDeclareKeyFlags( - [f.name for f in flag_values._GetKeyFlagsForModule(module.__name__)], - flag_values=flag_values) - # If module is this flag module, take _SPECIAL_FLAGS into account. - if module == _GetThisModuleObjectAndName()[0]: - _InternalDeclareKeyFlags( - # As we associate flags with _GetCallingModuleObjectAndName(), the - # special flags defined in this module are incorrectly registered with - # a different module. So, we can't use _GetKeyFlagsForModule. - # Instead, we take all flags from _SPECIAL_FLAGS (a private - # FlagValues, where no other module should register flags). - [f.name for f in _SPECIAL_FLAGS.FlagDict().values()], - flag_values=_SPECIAL_FLAGS, - key_flag_values=flag_values) - - -# -# STRING FLAGS -# - - -def DEFINE_string(name, default, help, flag_values=FLAGS, **args): - """Registers a flag whose value can be any string.""" - parser = ArgumentParser() - serializer = ArgumentSerializer() - DEFINE(parser, name, default, help, flag_values, serializer, **args) - - -# -# BOOLEAN FLAGS -# - - -class BooleanParser(ArgumentParser): - """Parser of boolean values.""" - - def Convert(self, argument): - """Converts the argument to a boolean; raise ValueError on errors.""" - if type(argument) == str: - if argument.lower() in ['true', 't', '1']: - return True - elif argument.lower() in ['false', 'f', '0']: - return False - - bool_argument = bool(argument) - if argument == bool_argument: - # The argument is a valid boolean (True, False, 0, or 1), and not just - # something that always converts to bool (list, string, int, etc.). 
- return bool_argument - - raise ValueError('Non-boolean argument to boolean flag', argument) - - def Parse(self, argument): - val = self.Convert(argument) - return val - - def Type(self): - return 'bool' - - -class BooleanFlag(Flag): - """Basic boolean flag. - - Boolean flags do not take any arguments, and their value is either - True (1) or False (0). The false value is specified on the command - line by prepending the word 'no' to either the long or the short flag - name. - - For example, if a Boolean flag was created whose long name was - 'update' and whose short name was 'x', then this flag could be - explicitly unset through either --noupdate or --nox. - """ - - def __init__(self, name, default, help, short_name=None, **args): - p = BooleanParser() - Flag.__init__(self, p, None, name, default, help, short_name, 1, **args) - if not self.help: self.help = "a boolean value" - - -def DEFINE_boolean(name, default, help, flag_values=FLAGS, **args): - """Registers a boolean flag. - - Such a boolean flag does not take an argument. If a user wants to - specify a false value explicitly, the long option beginning with 'no' - must be used: i.e. --noflag - - This flag will have a value of None, True or False. None is possible - if default=None and the user does not specify the flag on the command - line. - """ - DEFINE_flag(BooleanFlag(name, default, help, **args), flag_values) - - -# Match C++ API to unconfuse C++ people. -DEFINE_bool = DEFINE_boolean - - -class HelpFlag(BooleanFlag): - """ - HelpFlag is a special boolean flag that prints usage information and - raises a SystemExit exception if it is ever found in the command - line arguments. Note this is called with allow_override=1, so other - apps can define their own --help flag, replacing this one, if they want. 
- """ - def __init__(self): - BooleanFlag.__init__(self, "help", 0, "show this help", - short_name="?", allow_override=1) - def Parse(self, arg): - if arg: - doc = sys.modules["__main__"].__doc__ - flags = str(FLAGS) - print doc or ("\nUSAGE: %s [flags]\n" % sys.argv[0]) - if flags: - print "flags:" - print flags - sys.exit(1) -class HelpXMLFlag(BooleanFlag): - """Similar to HelpFlag, but generates output in XML format.""" - def __init__(self): - BooleanFlag.__init__(self, 'helpxml', False, - 'like --help, but generates XML output', - allow_override=1) - def Parse(self, arg): - if arg: - FLAGS.WriteHelpInXMLFormat(sys.stdout) - sys.exit(1) -class HelpshortFlag(BooleanFlag): - """ - HelpshortFlag is a special boolean flag that prints usage - information for the "main" module, and rasies a SystemExit exception - if it is ever found in the command line arguments. Note this is - called with allow_override=1, so other apps can define their own - --helpshort flag, replacing this one, if they want. - """ - def __init__(self): - BooleanFlag.__init__(self, "helpshort", 0, - "show usage only for this module", allow_override=1) - def Parse(self, arg): - if arg: - doc = sys.modules["__main__"].__doc__ - flags = FLAGS.MainModuleHelp() - print doc or ("\nUSAGE: %s [flags]\n" % sys.argv[0]) - if flags: - print "flags:" - print flags - sys.exit(1) - -# -# Numeric parser - base class for Integer and Float parsers -# - - -class NumericParser(ArgumentParser): - """Parser of numeric values. - - Parsed value may be bounded to a given upper and lower bound. 
- """ - - def IsOutsideBounds(self, val): - return ((self.lower_bound is not None and val < self.lower_bound) or - (self.upper_bound is not None and val > self.upper_bound)) - - def Parse(self, argument): - val = self.Convert(argument) - if self.IsOutsideBounds(val): - raise ValueError("%s is not %s" % (val, self.syntactic_help)) - return val - - def WriteCustomInfoInXMLFormat(self, outfile, indent): - if self.lower_bound is not None: - _WriteSimpleXMLElement(outfile, 'lower_bound', self.lower_bound, indent) - if self.upper_bound is not None: - _WriteSimpleXMLElement(outfile, 'upper_bound', self.upper_bound, indent) - - def Convert(self, argument): - """Default implementation: always returns its argument unmodified.""" - return argument - -# End of Numeric Parser - -# -# FLOAT FLAGS -# - - -class FloatParser(NumericParser): - """Parser of floating point values. - - Parsed value may be bounded to a given upper and lower bound. - """ - number_article = "a" - number_name = "number" - syntactic_help = " ".join((number_article, number_name)) - - def __init__(self, lower_bound=None, upper_bound=None): - super(FloatParser, self).__init__() - self.lower_bound = lower_bound - self.upper_bound = upper_bound - sh = self.syntactic_help - if lower_bound is not None and upper_bound is not None: - sh = ("%s in the range [%s, %s]" % (sh, lower_bound, upper_bound)) - elif lower_bound == 0: - sh = "a non-negative %s" % self.number_name - elif upper_bound == 0: - sh = "a non-positive %s" % self.number_name - elif upper_bound is not None: - sh = "%s <= %s" % (self.number_name, upper_bound) - elif lower_bound is not None: - sh = "%s >= %s" % (self.number_name, lower_bound) - self.syntactic_help = sh - - def Convert(self, argument): - """Converts argument to a float; raises ValueError on errors.""" - return float(argument) - - def Type(self): - return 'float' -# End of FloatParser - - -def DEFINE_float(name, default, help, lower_bound=None, upper_bound=None, - flag_values=FLAGS, 
**args): - """Registers a flag whose value must be a float. - - If lower_bound or upper_bound are set, then this flag must be - within the given range. - """ - parser = FloatParser(lower_bound, upper_bound) - serializer = ArgumentSerializer() - DEFINE(parser, name, default, help, flag_values, serializer, **args) - _RegisterBoundsValidatorIfNeeded(parser, name, flag_values=flag_values) - -# -# INTEGER FLAGS -# - - -class IntegerParser(NumericParser): - """Parser of an integer value. - - Parsed value may be bounded to a given upper and lower bound. - """ - number_article = "an" - number_name = "integer" - syntactic_help = " ".join((number_article, number_name)) - - def __init__(self, lower_bound=None, upper_bound=None): - super(IntegerParser, self).__init__() - self.lower_bound = lower_bound - self.upper_bound = upper_bound - sh = self.syntactic_help - if lower_bound is not None and upper_bound is not None: - sh = ("%s in the range [%s, %s]" % (sh, lower_bound, upper_bound)) - elif lower_bound == 1: - sh = "a positive %s" % self.number_name - elif upper_bound == -1: - sh = "a negative %s" % self.number_name - elif lower_bound == 0: - sh = "a non-negative %s" % self.number_name - elif upper_bound == 0: - sh = "a non-positive %s" % self.number_name - elif upper_bound is not None: - sh = "%s <= %s" % (self.number_name, upper_bound) - elif lower_bound is not None: - sh = "%s >= %s" % (self.number_name, lower_bound) - self.syntactic_help = sh - - def Convert(self, argument): - __pychecker__ = 'no-returnvalues' - if type(argument) == str: - base = 10 - if len(argument) > 2 and argument[0] == "0" and argument[1] == "x": - base = 16 - return int(argument, base) - else: - return int(argument) - - def Type(self): - return 'int' - - -def DEFINE_integer(name, default, help, lower_bound=None, upper_bound=None, - flag_values=FLAGS, **args): - """Registers a flag whose value must be an integer. 
- - If lower_bound, or upper_bound are set, then this flag must be - within the given range. - """ - parser = IntegerParser(lower_bound, upper_bound) - serializer = ArgumentSerializer() - DEFINE(parser, name, default, help, flag_values, serializer, **args) - _RegisterBoundsValidatorIfNeeded(parser, name, flag_values=flag_values) - - -# -# ENUM FLAGS -# - - -class EnumParser(ArgumentParser): - """Parser of a string enum value (a string value from a given set). - - If enum_values (see below) is not specified, any string is allowed. - """ - - def __init__(self, enum_values=None): - super(EnumParser, self).__init__() - self.enum_values = enum_values - - def Parse(self, argument): - if self.enum_values and argument not in self.enum_values: - raise ValueError("value should be one of <%s>" % - "|".join(self.enum_values)) - return argument - - def Type(self): - return 'string enum' - - -class EnumFlag(Flag): - """Basic enum flag; its value can be any string from list of enum_values.""" - - def __init__(self, name, default, help, enum_values=None, - short_name=None, **args): - enum_values = enum_values or [] - p = EnumParser(enum_values) - g = ArgumentSerializer() - Flag.__init__(self, p, g, name, default, help, short_name, **args) - if not self.help: self.help = "an enum string" - self.help = "<%s>: %s" % ("|".join(enum_values), self.help) - - def _WriteCustomInfoInXMLFormat(self, outfile, indent): - for enum_value in self.parser.enum_values: - _WriteSimpleXMLElement(outfile, 'enum_value', enum_value, indent) - - -def DEFINE_enum(name, default, enum_values, help, flag_values=FLAGS, - **args): - """Registers a flag whose value can be any string from enum_values.""" - DEFINE_flag(EnumFlag(name, default, help, enum_values, ** args), - flag_values) - - -# -# LIST FLAGS -# - - -class BaseListParser(ArgumentParser): - """Base class for a parser of lists of strings. 
- - To extend, inherit from this class; from the subclass __init__, call - - BaseListParser.__init__(self, token, name) - - where token is a character used to tokenize, and name is a description - of the separator. - """ - - def __init__(self, token=None, name=None): - assert name - super(BaseListParser, self).__init__() - self._token = token - self._name = name - self.syntactic_help = "a %s separated list" % self._name - - def Parse(self, argument): - if isinstance(argument, list): - return argument - elif argument == '': - return [] - else: - return [s.strip() for s in argument.split(self._token)] - - def Type(self): - return '%s separated list of strings' % self._name - - -class ListParser(BaseListParser): - """Parser for a comma-separated list of strings.""" - - def __init__(self): - BaseListParser.__init__(self, ',', 'comma') - - def WriteCustomInfoInXMLFormat(self, outfile, indent): - BaseListParser.WriteCustomInfoInXMLFormat(self, outfile, indent) - _WriteSimpleXMLElement(outfile, 'list_separator', repr(','), indent) - - -class WhitespaceSeparatedListParser(BaseListParser): - """Parser for a whitespace-separated list of strings.""" - - def __init__(self): - BaseListParser.__init__(self, None, 'whitespace') - - def WriteCustomInfoInXMLFormat(self, outfile, indent): - BaseListParser.WriteCustomInfoInXMLFormat(self, outfile, indent) - separators = list(string.whitespace) - separators.sort() - for ws_char in string.whitespace: - _WriteSimpleXMLElement(outfile, 'list_separator', repr(ws_char), indent) - - -def DEFINE_list(name, default, help, flag_values=FLAGS, **args): - """Registers a flag whose value is a comma-separated list of strings.""" - parser = ListParser() - serializer = ListSerializer(',') - DEFINE(parser, name, default, help, flag_values, serializer, **args) - - -def DEFINE_spaceseplist(name, default, help, flag_values=FLAGS, **args): - """Registers a flag whose value is a whitespace-separated list of strings. 
- - Any whitespace can be used as a separator. - """ - parser = WhitespaceSeparatedListParser() - serializer = ListSerializer(' ') - DEFINE(parser, name, default, help, flag_values, serializer, **args) - - -# -# MULTI FLAGS -# - - -class MultiFlag(Flag): - """A flag that can appear multiple time on the command-line. - - The value of such a flag is a list that contains the individual values - from all the appearances of that flag on the command-line. - - See the __doc__ for Flag for most behavior of this class. Only - differences in behavior are described here: - - * The default value may be either a single value or a list of values. - A single value is interpreted as the [value] singleton list. - - * The value of the flag is always a list, even if the option was - only supplied once, and even if the default value is a single - value - """ - - def __init__(self, *args, **kwargs): - Flag.__init__(self, *args, **kwargs) - self.help += ';\n repeat this option to specify a list of values' - - def Parse(self, arguments): - """Parses one or more arguments with the installed parser. - - Args: - arguments: a single argument or a list of arguments (typically a - list of default values); a single argument is converted - internally into a list containing one item. - """ - if not isinstance(arguments, list): - # Default value may be a list of values. Most other arguments - # will not be, so convert them into a single-item list to make - # processing simpler below. 
- arguments = [arguments] - - if self.present: - # keep a backup reference to list of previously supplied option values - values = self.value - else: - # "erase" the defaults with an empty list - values = [] - - for item in arguments: - # have Flag superclass parse argument, overwriting self.value reference - Flag.Parse(self, item) # also increments self.present - values.append(self.value) - - # put list of option values back in the 'value' attribute - self.value = values - - def Serialize(self): - if not self.serializer: - raise FlagsError("Serializer not present for flag %s" % self.name) - if self.value is None: - return '' - - s = '' - - multi_value = self.value - - for self.value in multi_value: - if s: s += ' ' - s += Flag.Serialize(self) - - self.value = multi_value - - return s - - def Type(self): - return 'multi ' + self.parser.Type() - - -def DEFINE_multi(parser, serializer, name, default, help, flag_values=FLAGS, - **args): - """Registers a generic MultiFlag that parses its args with a given parser. - - Auxiliary function. Normal users should NOT use it directly. - - Developers who need to create their own 'Parser' classes for options - which can appear multiple times can call this module function to - register their flags. - """ - DEFINE_flag(MultiFlag(parser, serializer, name, default, help, **args), - flag_values) - - -def DEFINE_multistring(name, default, help, flag_values=FLAGS, **args): - """Registers a flag whose value can be a list of any strings. - - Use the flag on the command line multiple times to place multiple - string values into the list. The 'default' may be a single string - (which will be converted into a single-element list) or a list of - strings. 
- """ - parser = ArgumentParser() - serializer = ArgumentSerializer() - DEFINE_multi(parser, serializer, name, default, help, flag_values, **args) - - -def DEFINE_multi_int(name, default, help, lower_bound=None, upper_bound=None, - flag_values=FLAGS, **args): - """Registers a flag whose value can be a list of arbitrary integers. - - Use the flag on the command line multiple times to place multiple - integer values into the list. The 'default' may be a single integer - (which will be converted into a single-element list) or a list of - integers. - """ - parser = IntegerParser(lower_bound, upper_bound) - serializer = ArgumentSerializer() - DEFINE_multi(parser, serializer, name, default, help, flag_values, **args) - - -def DEFINE_multi_float(name, default, help, lower_bound=None, upper_bound=None, - flag_values=FLAGS, **args): - """Registers a flag whose value can be a list of arbitrary floats. - - Use the flag on the command line multiple times to place multiple - float values into the list. The 'default' may be a single float - (which will be converted into a single-element list) or a list of - floats. - """ - parser = FloatParser(lower_bound, upper_bound) - serializer = ArgumentSerializer() - DEFINE_multi(parser, serializer, name, default, help, flag_values, **args) - - -# Now register the flags that we want to exist in all applications. -# These are all defined with allow_override=1, so user-apps can use -# these flagnames for their own purposes, if they want. -DEFINE_flag(HelpFlag()) -DEFINE_flag(HelpshortFlag()) -DEFINE_flag(HelpXMLFlag()) - -# Define special flags here so that help may be generated for them. -# NOTE: Please do NOT use _SPECIAL_FLAGS from outside this module. 
-_SPECIAL_FLAGS = FlagValues() - - -DEFINE_string( - 'flagfile', "", - "Insert flag definitions from the given file into the command line.", - _SPECIAL_FLAGS) - -DEFINE_string( - 'undefok', "", - "comma-separated list of flag names that it is okay to specify " - "on the command line even if the program does not define a flag " - "with that name. IMPORTANT: flags in this list that have " - "arguments MUST use the --flag=value format.", _SPECIAL_FLAGS) diff --git a/gflags_validators.py b/gflags_validators.py deleted file mode 100644 index d83058d..0000000 --- a/gflags_validators.py +++ /dev/null @@ -1,187 +0,0 @@ -#!/usr/bin/env python - -# Copyright (c) 2010, Google Inc. -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Module to enforce different constraints on flags. - -A validator represents an invariant, enforced over a one or more flags. -See 'FLAGS VALIDATORS' in gflags.py's docstring for a usage manual. -""" - -__author__ = 'olexiy@google.com (Olexiy Oryeshko)' - - -class Error(Exception): - """Thrown If validator constraint is not satisfied.""" - - -class Validator(object): - """Base class for flags validators. - - Users should NOT overload these classes, and use gflags.Register... - methods instead. - """ - - # Used to assign each validator an unique insertion_index - validators_count = 0 - - def __init__(self, checker, message): - """Constructor to create all validators. - - Args: - checker: function to verify the constraint. - Input of this method varies, see SimpleValidator and - DictionaryValidator for a detailed description. - message: string, error message to be shown to the user - """ - self.checker = checker - self.message = message - Validator.validators_count += 1 - # Used to assert validators in the order they were registered (CL/18694236) - self.insertion_index = Validator.validators_count - - def Verify(self, flag_values): - """Verify that constraint is satisfied. - - flags library calls this method to verify Validator's constraint. - Args: - flag_values: gflags.FlagValues, containing all flags - Raises: - Error: if constraint is not satisfied. 
- """ - param = self._GetInputToCheckerFunction(flag_values) - if not self.checker(param): - raise Error(self.message) - - def GetFlagsNames(self): - """Return the names of the flags checked by this validator. - - Returns: - [string], names of the flags - """ - raise NotImplementedError('This method should be overloaded') - - def PrintFlagsWithValues(self, flag_values): - raise NotImplementedError('This method should be overloaded') - - def _GetInputToCheckerFunction(self, flag_values): - """Given flag values, construct the input to be given to checker. - - Args: - flag_values: gflags.FlagValues, containing all flags. - Returns: - Return type depends on the specific validator. - """ - raise NotImplementedError('This method should be overloaded') - - -class SimpleValidator(Validator): - """Validator behind RegisterValidator() method. - - Validates that a single flag passes its checker function. The checker function - takes the flag value and returns True (if value looks fine) or, if flag value - is not valid, either returns False or raises an Exception.""" - def __init__(self, flag_name, checker, message): - """Constructor. - - Args: - flag_name: string, name of the flag. - checker: function to verify the validator. - input - value of the corresponding flag (string, boolean, etc). - output - Boolean. Must return True if validator constraint is satisfied. - If constraint is not satisfied, it should either return False or - raise Error. - message: string, error message to be shown to the user if validator's - condition is not satisfied - """ - super(SimpleValidator, self).__init__(checker, message) - self.flag_name = flag_name - - def GetFlagsNames(self): - return [self.flag_name] - - def PrintFlagsWithValues(self, flag_values): - return 'flag --%s=%s' % (self.flag_name, flag_values[self.flag_name].value) - - def _GetInputToCheckerFunction(self, flag_values): - """Given flag values, construct the input to be given to checker. 
- - Args: - flag_values: gflags.FlagValues - Returns: - value of the corresponding flag. - """ - return flag_values[self.flag_name].value - - -class DictionaryValidator(Validator): - """Validator behind RegisterDictionaryValidator method. - - Validates that flag values pass their common checker function. The checker - function takes flag values and returns True (if values look fine) or, - if values are not valid, either returns False or raises an Exception. - """ - def __init__(self, flag_names, checker, message): - """Constructor. - - Args: - flag_names: [string], containing names of the flags used by checker. - checker: function to verify the validator. - input - dictionary, with keys() being flag_names, and value for each - key being the value of the corresponding flag (string, boolean, etc). - output - Boolean. Must return True if validator constraint is satisfied. - If constraint is not satisfied, it should either return False or - raise Error. - message: string, error message to be shown to the user if validator's - condition is not satisfied - """ - super(DictionaryValidator, self).__init__(checker, message) - self.flag_names = flag_names - - def _GetInputToCheckerFunction(self, flag_values): - """Given flag values, construct the input to be given to checker. - - Args: - flag_values: gflags.FlagValues - Returns: - dictionary, with keys() being self.lag_names, and value for each key - being the value of the corresponding flag (string, boolean, etc). 
- """ - return dict([key, flag_values[key].value] for key in self.flag_names) - - def PrintFlagsWithValues(self, flag_values): - prefix = 'flags ' - flags_with_values = [] - for key in self.flag_names: - flags_with_values.append('%s=%s' % (key, flag_values[key].value)) - return prefix + ', '.join(flags_with_values) - - def GetFlagsNames(self): - return self.flag_names diff --git a/gimaplib.py b/gimaplib.py deleted file mode 100644 index 1d508d1..0000000 --- a/gimaplib.py +++ /dev/null @@ -1,137 +0,0 @@ -# Functions that make IMAP behave more Gmail-ish - -import imaplib -import re -import shlex -import sys - -import gyb - -def GImapHasExtensions(imapconn): - ''' - Args: - imapconn: object, an authenticated IMAP connection - - Returns: - boolean, True if Gmail IMAP Extensions defined at: - http://code.google.com/apis/gmail/imap - are supported, False if not. - ''' - t, d = imapconn.capability() - if t != 'OK': - raise GImapHasExtensionsError('GImap Has Extensions could not check server capabilities: %s' % t) - return bool(d[0].count('X-GM-EXT-1')) - -def GImapSendID(imapconn, name, version, vendor, contact): - ''' - Args: - imapconn: object, an authenticated IMAP connection - name: string, IMAP Client Name - version: string, IMAP Client Version - vendor: string, IMAP Client Vendor - contact: string, email address of contact - - Returns: - list of IMAP Server ID response values - ''' - commands = {'ID' : ('AUTH',)} - imaplib.Commands.update(commands) - id = '("name" "%s" "version" "%s" "vendor" "%s" "contact" "%s")' % (name, version, vendor, contact) - t, d = imapconn._simple_command('ID', id) - r, d = imapconn._untagged_response(t, d, 'ID') - if r != 'OK': - raise GImapSendIDError('GImap Send ID failed to send ID: %s' % t) - return shlex.split(d[0][1:-1]) - -def ImapConnect(xoauth_string, debug): - imap_conn = imaplib.IMAP4_SSL('imap.gmail.com') - if debug: - imap_conn.debug = 4 - imap_conn.authenticate('XOAUTH2', lambda x: xoauth_string) - if not 
GImapHasExtensions(imap_conn): - print "This server does not support the Gmail IMAP Extensions." - sys.exit(1) - GImapSendID(imap_conn, gyb.__program_name__, gyb.__version__, gyb.__author__, gyb.__email__) - return imap_conn - -def GImapSearch(imapconn, gmail_search): - ''' - Args: - imapconn: object, an authenticated IMAP connection to a server supporting the X-GM-EXT1 IMAP capability (imap.gmail.com) - gmail_search: string, a typical Gmail search as defined at: - http://mail.google.com/support/bin/answer.py?answer=7190 - - Returns: - list, the IMAP UIDs of messages that match the search - - Note: Only the IMAP Selected folder is searched, it's as if 'in:' is appended to all searches. If you wish to search all mail, select '[Gmail]/All Mail' before performing the search. - ''' - #t, d = imapconn.search(None, 'X-GM-RAW', gmail_search) - gmail_search = gmail_search.replace('\\', '\\\\').replace('"', '\\"') - gmail_search = '"' + gmail_search + '"' - imapconn.literal = gmail_search - t, d = imapconn.uid('SEARCH', 'CHARSET', 'UTF-8', 'X-GM-RAW') - if t != 'OK': - raise GImapSearchError('GImap Search Failed: %s' % t) - return d[0].split() - -def GImapGetMessageLabels(imapconn, uid): - ''' - Args: - imapconn: object, an authenticated IMAP connection to a server supporting the X-GM-EXT1 IMAP capability (imap.gmail.com) - uid: int, the IMAP UID for the message whose labels you wish to learn. 
- - Returns: - list, the Gmail Labels of the message - ''' - t, d = imapconn.uid('FETCH', uid, '(X-GM-LABELS)') - if t != 'OK': - raise GImapGetMessageLabelsError('GImap Get Message Labels Failed: %s' % t) - if d[0] != None: - labels = re.search('^[0-9]* \(X-GM-LABELS \((.*?)\) UID %s\)' % uid, d[0]).group(1) - labels_list = shlex.split(labels) - else: - labels_list = [] - return labels_list - -def GImapSetMessageLabels(imapconn, uid, labels): - ''' - Args: - imapconn: object, an authenticated IMAP connection to a server supporting the X-GM-EXT1 IMAP capability (imap.gmail.com) - uid: int, the IMAP UID for the message whose labels you wish to learn. - labels: list, names of labels to be applied to the message - - Returns: - null on success or Error on failure - - Note: specified labels are added but the message's existing labels that are not specified are not removed. - ''' - labels_string = '"'+'" "'.join(labels)+'"' - t, d = imapconn.uid('STORE', uid, '+X-GM-LABELS', labels_string) - if t != 'OK': - print 'GImap Set Message Labels Failed: %s' % t - exit(33) - -def GImapGetFolders(imapconn): - ''' - Args: - imapconn: object, an authenticated IMAP connection - - Returns: - dictionary, Gmail special folder types mapped to their localized name - ''' - list_response_pattern = re.compile(r'\((?P.*?)\) "(?P.*)" (?P.*)') - for prefix in ['"[Gmail]/"', '"[Google Mail]/"', '""']: - t, d = imapconn.list(prefix, '*') - if t != 'OK': - raise GImapHasExtensionsError('GImap Get Folder could not check server LIST: %s' % t) - if d != [None]: - break - label_mappings = {} - for line in d: - flags, delimiter, label_local_name = list_response_pattern.match(line).groups() - flags_list = flags.split(' ') - for flag in flags_list: - if flag not in [u'\\HasNoChildren', u'\\HasChildren', u'\\Noinferiors']: - label_mappings[flag] = label_local_name[1:-1] - return label_mappings diff --git a/oauth2client/anyjson.py b/googleapiclient/__init__.py similarity index 52% rename from 
oauth2client/anyjson.py rename to googleapiclient/__init__.py index ae21c33..c000e97 100644 --- a/oauth2client/anyjson.py +++ b/googleapiclient/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility module to import a JSON module - -Hides all the messy details of exactly where -we get a simplejson module from. -""" - -__author__ = 'jcgregorio@google.com (Joe Gregorio)' - - -try: # pragma: no cover - # Should work for Python2.6 and higher. - import json as simplejson -except ImportError: # pragma: no cover - try: - import simplejson - except ImportError: - # Try to import from django, should work on App Engine - from django.utils import simplejson +__version__ = "1.4.1" diff --git a/googleapiclient/channel.py b/googleapiclient/channel.py new file mode 100644 index 0000000..702186b --- /dev/null +++ b/googleapiclient/channel.py @@ -0,0 +1,287 @@ +"""Channel notifications support. + +Classes and functions to support channel subscriptions and notifications +on those channels. + +Notes: + - This code is based on experimental APIs and is subject to change. + - Notification does not do deduplication of notification ids, that's up to + the receiver. + - Storing the Channel between calls is up to the caller. + + +Example setting up a channel: + + # Create a new channel that gets notifications via webhook. + channel = new_webhook_channel("https://example.com/my_web_hook") + + # Store the channel, keyed by 'channel.id'. Store it before calling the + # watch method because notifications may start arriving before the watch + # method returns. + ... 
+ + resp = service.objects().watchAll( + bucket="some_bucket_id", body=channel.body()).execute() + channel.update(resp) + + # Store the channel, keyed by 'channel.id'. Store it after being updated + # since the resource_id value will now be correct, and that's needed to + # stop a subscription. + ... + + +An example Webhook implementation using webapp2. Note that webapp2 puts +headers in a case insensitive dictionary, as headers aren't guaranteed to +always be upper case. + + id = self.request.headers[X_GOOG_CHANNEL_ID] + + # Retrieve the channel by id. + channel = ... + + # Parse notification from the headers, including validating the id. + n = notification_from_headers(channel, self.request.headers) + + # Do app specific stuff with the notification here. + if n.resource_state == 'sync': + # Code to handle sync state. + elif n.resource_state == 'exists': + # Code to handle the exists state. + elif n.resource_state == 'not_exists': + # Code to handle the not exists state. + + +Example of unsubscribing. + + service.channels().stop(channel.body()) +""" +from __future__ import absolute_import + +import datetime +import uuid + +from googleapiclient import errors +from oauth2client import util +import six + + +# The unix time epoch starts at midnight 1970. +EPOCH = datetime.datetime.utcfromtimestamp(0) + +# Map the names of the parameters in the JSON channel description to +# the parameter names we use in the Channel class. 
+CHANNEL_PARAMS = { + 'address': 'address', + 'id': 'id', + 'expiration': 'expiration', + 'params': 'params', + 'resourceId': 'resource_id', + 'resourceUri': 'resource_uri', + 'type': 'type', + 'token': 'token', + } + +X_GOOG_CHANNEL_ID = 'X-GOOG-CHANNEL-ID' +X_GOOG_MESSAGE_NUMBER = 'X-GOOG-MESSAGE-NUMBER' +X_GOOG_RESOURCE_STATE = 'X-GOOG-RESOURCE-STATE' +X_GOOG_RESOURCE_URI = 'X-GOOG-RESOURCE-URI' +X_GOOG_RESOURCE_ID = 'X-GOOG-RESOURCE-ID' + + +def _upper_header_keys(headers): + new_headers = {} + for k, v in six.iteritems(headers): + new_headers[k.upper()] = v + return new_headers + + +class Notification(object): + """A Notification from a Channel. + + Notifications are not usually constructed directly, but are returned + from functions like notification_from_headers(). + + Attributes: + message_number: int, The unique id number of this notification. + state: str, The state of the resource being monitored. + uri: str, The address of the resource being monitored. + resource_id: str, The unique identifier of the version of the resource at + this event. + """ + @util.positional(5) + def __init__(self, message_number, state, resource_uri, resource_id): + """Notification constructor. + + Args: + message_number: int, The unique id number of this notification. + state: str, The state of the resource being monitored. Can be one + of "exists", "not_exists", or "sync". + resource_uri: str, The address of the resource being monitored. + resource_id: str, The identifier of the watched resource. + """ + self.message_number = message_number + self.state = state + self.resource_uri = resource_uri + self.resource_id = resource_id + + +class Channel(object): + """A Channel for notifications. + + Usually not constructed directly, instead it is returned from helper + functions like new_webhook_channel(). + + Attributes: + type: str, The type of delivery mechanism used by this channel. For + example, 'web_hook'. + id: str, A UUID for the channel. 
+ token: str, An arbitrary string associated with the channel that + is delivered to the target address with each event delivered + over this channel. + address: str, The address of the receiving entity where events are + delivered. Specific to the channel type. + expiration: int, The time, in milliseconds from the epoch, when this + channel will expire. + params: dict, A dictionary of string to string, with additional parameters + controlling delivery channel behavior. + resource_id: str, An opaque id that identifies the resource that is + being watched. Stable across different API versions. + resource_uri: str, The canonicalized ID of the watched resource. + """ + + @util.positional(5) + def __init__(self, type, id, token, address, expiration=None, + params=None, resource_id="", resource_uri=""): + """Create a new Channel. + + In user code, this Channel constructor will not typically be called + manually since there are functions for creating channels for each specific + type with a more customized set of arguments to pass. + + Args: + type: str, The type of delivery mechanism used by this channel. For + example, 'web_hook'. + id: str, A UUID for the channel. + token: str, An arbitrary string associated with the channel that + is delivered to the target address with each event delivered + over this channel. + address: str, The address of the receiving entity where events are + delivered. Specific to the channel type. + expiration: int, The time, in milliseconds from the epoch, when this + channel will expire. + params: dict, A dictionary of string to string, with additional parameters + controlling delivery channel behavior. + resource_id: str, An opaque id that identifies the resource that is + being watched. Stable across different API versions. + resource_uri: str, The canonicalized ID of the watched resource. 
+ """ + self.type = type + self.id = id + self.token = token + self.address = address + self.expiration = expiration + self.params = params + self.resource_id = resource_id + self.resource_uri = resource_uri + + def body(self): + """Build a body from the Channel. + + Constructs a dictionary that's appropriate for passing into watch() + methods as the value of body argument. + + Returns: + A dictionary representation of the channel. + """ + result = { + 'id': self.id, + 'token': self.token, + 'type': self.type, + 'address': self.address + } + if self.params: + result['params'] = self.params + if self.resource_id: + result['resourceId'] = self.resource_id + if self.resource_uri: + result['resourceUri'] = self.resource_uri + if self.expiration: + result['expiration'] = self.expiration + + return result + + def update(self, resp): + """Update a channel with information from the response of watch(). + + When a request is sent to watch() a resource, the response returned + from the watch() request is a dictionary with updated channel information, + such as the resource_id, which is needed when stopping a subscription. + + Args: + resp: dict, The response from a watch() method. + """ + for json_name, param_name in six.iteritems(CHANNEL_PARAMS): + value = resp.get(json_name) + if value is not None: + setattr(self, param_name, value) + + +def notification_from_headers(channel, headers): + """Parse a notification from the webhook request headers, validate + the notification, and return a Notification object. + + Args: + channel: Channel, The channel that the notification is associated with. + headers: dict, A dictionary like object that contains the request headers + from the webhook HTTP request. + + Returns: + A Notification object. + + Raises: + errors.InvalidNotificationError if the notification is invalid. + ValueError if the X-GOOG-MESSAGE-NUMBER can't be converted to an int. 
+ """ + headers = _upper_header_keys(headers) + channel_id = headers[X_GOOG_CHANNEL_ID] + if channel.id != channel_id: + raise errors.InvalidNotificationError( + 'Channel id mismatch: %s != %s' % (channel.id, channel_id)) + else: + message_number = int(headers[X_GOOG_MESSAGE_NUMBER]) + state = headers[X_GOOG_RESOURCE_STATE] + resource_uri = headers[X_GOOG_RESOURCE_URI] + resource_id = headers[X_GOOG_RESOURCE_ID] + return Notification(message_number, state, resource_uri, resource_id) + + +@util.positional(2) +def new_webhook_channel(url, token=None, expiration=None, params=None): + """Create a new webhook Channel. + + Args: + url: str, URL to post notifications to. + token: str, An arbitrary string associated with the channel that + is delivered to the target address with each notification delivered + over this channel. + expiration: datetime.datetime, A time in the future when the channel + should expire. Can also be None if the subscription should use the + default expiration. Note that different services may have different + limits on how long a subscription lasts. Check the response from the + watch() method to see the value the service has set for an expiration + time. + params: dict, Extra parameters to pass on channel creation. Currently + not used for webhook channels. + """ + expiration_ms = 0 + if expiration: + delta = expiration - EPOCH + expiration_ms = delta.microseconds/1000 + ( + delta.seconds + delta.days*24*3600)*1000 + if expiration_ms < 0: + expiration_ms = 0 + + return Channel('web_hook', str(uuid.uuid4()), + token, url, expiration=expiration_ms, + params=params) + diff --git a/apiclient/discovery.py b/googleapiclient/discovery.py similarity index 83% rename from apiclient/discovery.py rename to googleapiclient/discovery.py index 4c6cb60..4109865 100644 --- a/apiclient/discovery.py +++ b/googleapiclient/discovery.py @@ -1,4 +1,4 @@ -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ A client library for Google's discovery based APIs. """ +from __future__ import absolute_import +import six +from six.moves import zip __author__ = 'jcgregorio@google.com (Joe Gregorio)' __all__ = [ @@ -25,44 +28,43 @@ 'key2param', ] +from six import StringIO +from six.moves.urllib.parse import urlencode, urlparse, urljoin, \ + urlunparse, parse_qsl # Standard library imports import copy +from email.generator import Generator from email.mime.multipart import MIMEMultipart from email.mime.nonmultipart import MIMENonMultipart +import json import keyword import logging import mimetypes import os import re -import urllib -import urlparse - -try: - from urlparse import parse_qsl -except ImportError: - from cgi import parse_qsl # Third-party imports import httplib2 -import mimeparse import uritemplate # Local imports -from apiclient.errors import HttpError -from apiclient.errors import InvalidJsonError -from apiclient.errors import MediaUploadSizeError -from apiclient.errors import UnacceptableMimeTypeError -from apiclient.errors import UnknownApiNameOrVersion -from apiclient.errors import UnknownFileType -from apiclient.http import HttpRequest -from apiclient.http import MediaFileUpload -from apiclient.http import MediaUpload -from apiclient.model import JsonModel -from apiclient.model import MediaModel -from apiclient.model import RawModel -from apiclient.schema import Schemas -from oauth2client.anyjson import simplejson +from googleapiclient import mimeparse +from googleapiclient.errors import HttpError +from googleapiclient.errors import InvalidJsonError +from googleapiclient.errors import MediaUploadSizeError +from googleapiclient.errors import UnacceptableMimeTypeError +from googleapiclient.errors import UnknownApiNameOrVersion +from googleapiclient.errors import UnknownFileType +from googleapiclient.http import 
BatchHttpRequest +from googleapiclient.http import HttpRequest +from googleapiclient.http import MediaFileUpload +from googleapiclient.http import MediaUpload +from googleapiclient.model import JsonModel +from googleapiclient.model import MediaModel +from googleapiclient.model import RawModel +from googleapiclient.schema import Schemas +from oauth2client.client import GoogleCredentials from oauth2client.util import _add_query_parameter from oauth2client.util import positional @@ -146,7 +148,8 @@ def build(serviceName, discoveryServiceUrl=DISCOVERY_URI, developerKey=None, model=None, - requestBuilder=HttpRequest): + requestBuilder=HttpRequest, + credentials=None): """Construct a Resource for interacting with an API. Construct a Resource object for interacting with an API. The serviceName and @@ -163,9 +166,11 @@ def build(serviceName, document for that service. developerKey: string, key obtained from https://code.google.com/apis/console. - model: apiclient.Model, converts to and from the wire format. - requestBuilder: apiclient.http.HttpRequest, encapsulator for an HTTP + model: googleapiclient.Model, converts to and from the wire format. + requestBuilder: googleapiclient.http.HttpRequest, encapsulator for an HTTP request. + credentials: oauth2client.Credentials, credentials to be used for + authentication. Returns: A Resource object with methods for interacting with the service. 
@@ -187,7 +192,7 @@ def build(serviceName, if 'REMOTE_ADDR' in os.environ: requested_url = _add_query_parameter(requested_url, 'userIp', os.environ['REMOTE_ADDR']) - logger.info('URL being requested: %s' % requested_url) + logger.info('URL being requested: GET %s' % requested_url) resp, content = http.request(requested_url) @@ -198,13 +203,19 @@ def build(serviceName, raise HttpError(resp, content, uri=requested_url) try: - service = simplejson.loads(content) - except ValueError, e: + content = content.decode('utf-8') + except AttributeError: + pass + + try: + service = json.loads(content) + except ValueError as e: logger.error('Failed to parse as JSON: ' + content) raise InvalidJsonError() return build_from_document(content, base=discoveryServiceUrl, http=http, - developerKey=developerKey, model=model, requestBuilder=requestBuilder) + developerKey=developerKey, model=model, requestBuilder=requestBuilder, + credentials=credentials) @positional(1) @@ -215,7 +226,8 @@ def build_from_document( http=None, developerKey=None, model=None, - requestBuilder=HttpRequest): + requestBuilder=HttpRequest, + credentials=None): """Create a Resource for interacting with an API. Same as `build()`, but constructs the Resource object from a discovery @@ -236,6 +248,7 @@ def build_from_document( model: Model class instance that serializes and de-serializes requests and responses. requestBuilder: Takes an http request and packages it up to be executed. + credentials: object, credentials to be used for authentication. Returns: A Resource object with methods for interacting with the service. @@ -244,11 +257,33 @@ def build_from_document( # future is no longer used. 
future = {} - if isinstance(service, basestring): - service = simplejson.loads(service) - base = urlparse.urljoin(service['rootUrl'], service['servicePath']) + if isinstance(service, six.string_types): + service = json.loads(service) + base = urljoin(service['rootUrl'], service['servicePath']) schema = Schemas(service) + if credentials: + # If credentials were passed in, we could have two cases: + # 1. the scopes were specified, in which case the given credentials + # are used for authorizing the http; + # 2. the scopes were not provided (meaning the Application Default + # Credentials are to be used). In this case, the Application Default + # Credentials are built and used instead of the original credentials. + # If there are no scopes found (meaning the given service requires no + # authentication), there is no authorization of the http. + if (isinstance(credentials, GoogleCredentials) and + credentials.create_scoped_required()): + scopes = service.get('auth', {}).get('oauth2', {}).get('scopes', {}) + if scopes: + credentials = credentials.create_scoped(list(scopes.keys())) + else: + # No need to authorize the http object + # if the service does not require authentication. + credentials = None + + if credentials: + http = credentials.authorize(http) + if model is None: features = service.get('features', []) model = JsonModel('dataWrapper' in features) @@ -298,13 +333,13 @@ def _media_size_to_long(maxSize): The size as an integer value. """ if len(maxSize) < 2: - return 0L + return 0 units = maxSize[-2:].upper() bit_shift = _MEDIA_SIZE_BIT_SHIFTS.get(units) if bit_shift is not None: - return long(maxSize[:-2]) << bit_shift + return int(maxSize[:-2]) << bit_shift else: - return long(maxSize) + return int(maxSize) def _media_path_url_from_info(root_desc, path_url): @@ -354,7 +389,7 @@ def _fix_up_parameters(method_desc, root_desc, http_method): parameters = method_desc.setdefault('parameters', {}) # Add in the parameters common to all methods. 
- for name, description in root_desc.get('parameters', {}).iteritems(): + for name, description in six.iteritems(root_desc.get('parameters', {})): parameters[name] = description # Add in undocumented query parameters. @@ -460,6 +495,23 @@ def _fix_up_method_description(method_desc, root_desc): return path_url, http_method, method_id, accept, max_size, media_path_url +def _urljoin(base, url): + """Custom urljoin replacement supporting : before / in url.""" + # In general, it's unsafe to simply join base and url. However, for + # the case of discovery documents, we know: + # * base will never contain params, query, or fragment + # * url will never contain a scheme or net_loc. + # In general, this means we can safely join on /; we just need to + # ensure we end up with precisely one / joining base and url. The + # exception here is the case of media uploads, where url will be an + # absolute url. + if url.startswith('http://') or url.startswith('https://'): + return urljoin(base, url) + new_base = base if base.endswith('/') else base + '/' + new_url = url[1:] if url.startswith('/') else url + return new_base + new_url + + # TODO(dhermes): Convert this class to ResourceMethod and make it callable class ResourceMethodParameters(object): """Represents the parameters associated with a method. @@ -520,7 +572,7 @@ def set_parameters(self, method_desc): comes from the dictionary of methods stored in the 'methods' key in the deserialized discovery document. """ - for arg, desc in method_desc.get('parameters', {}).iteritems(): + for arg, desc in six.iteritems(method_desc.get('parameters', {})): param = key2param(arg) self.argmap[param] = arg @@ -568,12 +620,12 @@ def createMethod(methodName, methodDesc, rootDesc, schema): def method(self, **kwargs): # Don't bother with doc string, it will be over-written by createMethod. 
- for name in kwargs.iterkeys(): + for name in six.iterkeys(kwargs): if name not in parameters.argmap: raise TypeError('Got an unexpected keyword argument "%s"' % name) # Remove args that have a value of None. - keys = kwargs.keys() + keys = list(kwargs.keys()) for name in keys: if kwargs[name] is None: del kwargs[name] @@ -582,9 +634,9 @@ def method(self, **kwargs): if name not in kwargs: raise TypeError('Missing required parameter "%s"' % name) - for name, regex in parameters.pattern_params.iteritems(): + for name, regex in six.iteritems(parameters.pattern_params): if name in kwargs: - if isinstance(kwargs[name], basestring): + if isinstance(kwargs[name], six.string_types): pvalues = [kwargs[name]] else: pvalues = kwargs[name] @@ -594,13 +646,13 @@ def method(self, **kwargs): 'Parameter "%s" value "%s" does not match the pattern "%s"' % (name, pvalue, regex)) - for name, enums in parameters.enum_params.iteritems(): + for name, enums in six.iteritems(parameters.enum_params): if name in kwargs: # We need to handle the case of a repeated enum # name differently, since we want to handle both # arg='value' and arg=['value1', 'value2'] if (name in parameters.repeated_params and - not isinstance(kwargs[name], basestring)): + not isinstance(kwargs[name], six.string_types)): values = kwargs[name] else: values = [kwargs[name]] @@ -612,7 +664,7 @@ def method(self, **kwargs): actual_query_params = {} actual_path_params = {} - for key, value in kwargs.iteritems(): + for key, value in six.iteritems(kwargs): to_type = parameters.param_types.get(key, 'string') # For repeated parameters we cast each member of the list. 
if key in parameters.repeated_params and type(value) == type([]): @@ -640,14 +692,14 @@ def method(self, **kwargs): actual_path_params, actual_query_params, body_value) expanded_url = uritemplate.expand(pathUrl, params) - url = urlparse.urljoin(self._baseUrl, expanded_url + query) + url = _urljoin(self._baseUrl, expanded_url + query) resumable = None multipart_boundary = '' if media_filename: # Ensure we end up with a valid MediaUpload object. - if isinstance(media_filename, basestring): + if isinstance(media_filename, six.string_types): (media_mime_type, encoding) = mimetypes.guess_type(media_filename) if media_mime_type is None: raise UnknownFileType(media_filename) @@ -661,12 +713,12 @@ def method(self, **kwargs): raise TypeError('media_filename must be str or MediaUpload.') # Check the maxSize - if maxSize > 0 and media_upload.size() > maxSize: + if media_upload.size() is not None and media_upload.size() > maxSize > 0: raise MediaUploadSizeError("Media larger than: %s" % maxSize) # Use the media path uri for media uploads expanded_url = uritemplate.expand(mediaPathUrl, params) - url = urlparse.urljoin(self._baseUrl, expanded_url + query) + url = _urljoin(self._baseUrl, expanded_url + query) if media_upload.resumable(): url = _add_query_parameter(url, 'uploadType', 'resumable') @@ -699,14 +751,19 @@ def method(self, **kwargs): payload = media_upload.getbytes(0, media_upload.size()) msg.set_payload(payload) msgRoot.attach(msg) - body = msgRoot.as_string() + # encode the body: note that we can't use `as_string`, because + # it plays games with `From ` lines. 
+ fp = StringIO() + g = Generator(fp, mangle_from_=False) + g.flatten(msgRoot, unixfrom=False) + body = fp.getvalue() multipart_boundary = msgRoot.get_boundary() headers['content-type'] = ('multipart/related; ' 'boundary="%s"') % multipart_boundary url = _add_query_parameter(url, 'uploadType', 'multipart') - logger.info('URL being requested: %s' % url) + logger.info('URL being requested: %s %s' % (httpMethod,url)) return self._requestBuilder(self._http, model.response, url, @@ -721,10 +778,10 @@ def method(self, **kwargs): docs.append('Args:\n') # Skip undocumented params and params common to all methods. - skip_parameters = rootDesc.get('parameters', {}).keys() + skip_parameters = list(rootDesc.get('parameters', {}).keys()) skip_parameters.extend(STACK_QUERY_PARAMETERS) - all_args = parameters.argmap.keys() + all_args = list(parameters.argmap.keys()) args_ordered = [key2param(s) for s in methodDesc.get('parameterOrder', [])] # Move body to the front of the line. @@ -803,18 +860,18 @@ def methodNext(self, previous_request, previous_response): request = copy.copy(previous_request) pageToken = previous_response['nextPageToken'] - parsed = list(urlparse.urlparse(request.uri)) + parsed = list(urlparse(request.uri)) q = parse_qsl(parsed[4]) # Find and remove old 'pageToken' value from URI newq = [(key, value) for (key, value) in q if key != 'pageToken'] newq.append(('pageToken', pageToken)) - parsed[4] = urllib.urlencode(newq) - uri = urlparse.urlunparse(parsed) + parsed[4] = urlencode(newq) + uri = urlunparse(parsed) request.uri = uri - logger.info('URL being requested: %s' % uri) + logger.info('URL being requested: %s %s' % (methodName,uri)) return request @@ -832,9 +889,9 @@ def __init__(self, http, baseUrl, model, requestBuilder, developerKey, http: httplib2.Http, Object to make http requests with. baseUrl: string, base URL for the API. All requests are relative to this URI. - model: apiclient.Model, converts to and from the wire format. 
+ model: googleapiclient.Model, converts to and from the wire format. requestBuilder: class or callable that instantiates an - apiclient.HttpRequest object. + googleapiclient.HttpRequest object. developerKey: string, key obtained from https://code.google.com/apis/console resourceDesc: object, section of deserialized discovery document that @@ -894,9 +951,30 @@ def _set_service_methods(self): self._add_next_methods(self._resourceDesc, self._schema) def _add_basic_methods(self, resourceDesc, rootDesc, schema): + # If this is the root Resource, add a new_batch_http_request() method. + if resourceDesc == rootDesc: + batch_uri = '%s%s' % ( + rootDesc['rootUrl'], rootDesc.get('batchPath', 'batch')) + def new_batch_http_request(callback=None): + """Create a BatchHttpRequest object based on the discovery document. + + Args: + callback: callable, A callback to be called for each response, of the + form callback(id, response, exception). The first parameter is the + request id, and the second is the deserialized response object. The + third is an apiclient.errors.HttpError exception object if an HTTP + error occurred while processing the request, or None if no error + occurred. + + Returns: + A BatchHttpRequest object based on the discovery document. 
+ """ + return BatchHttpRequest(callback=callback, batch_uri=batch_uri) + self._set_dynamic_attr('new_batch_http_request', new_batch_http_request) + # Add basic methods to Resource if 'methods' in resourceDesc: - for methodName, methodDesc in resourceDesc['methods'].iteritems(): + for methodName, methodDesc in six.iteritems(resourceDesc['methods']): fixedMethodName, method = createMethod( methodName, methodDesc, rootDesc, schema) self._set_dynamic_attr(fixedMethodName, @@ -935,7 +1013,7 @@ def methodResource(self): return (methodName, methodResource) - for methodName, methodDesc in resourceDesc['resources'].iteritems(): + for methodName, methodDesc in six.iteritems(resourceDesc['resources']): fixedMethodName, method = createResourceMethod(methodName, methodDesc) self._set_dynamic_attr(fixedMethodName, method.__get__(self, self.__class__)) @@ -945,7 +1023,7 @@ def _add_next_methods(self, resourceDesc, schema): # Look for response bodies in schema that contain nextPageToken, and methods # that take a pageToken parameter. if 'methods' in resourceDesc: - for methodName, methodDesc in resourceDesc['methods'].iteritems(): + for methodName, methodDesc in six.iteritems(resourceDesc['methods']): if 'response' in methodDesc: responseSchema = methodDesc['response'] if '$ref' in responseSchema: diff --git a/apiclient/errors.py b/googleapiclient/errors.py similarity index 90% rename from apiclient/errors.py rename to googleapiclient/errors.py index 2bf9149..3d44de7 100644 --- a/apiclient/errors.py +++ b/googleapiclient/errors.py @@ -1,6 +1,4 @@ -#!/usr/bin/python2.4 -# -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,12 +17,13 @@ All exceptions defined by the library should be defined in this file. 
""" +from __future__ import absolute_import __author__ = 'jcgregorio@google.com (Joe Gregorio)' +import json from oauth2client import util -from oauth2client.anyjson import simplejson class Error(Exception): @@ -38,6 +37,8 @@ class HttpError(Error): @util.positional(3) def __init__(self, resp, content, uri=None): self.resp = resp + if not isinstance(content, bytes): + raise TypeError("HTTP content should be bytes") self.content = content self.uri = uri @@ -45,7 +46,7 @@ def _get_reason(self): """Calculate the reason for the error from the response content.""" reason = self.resp.reason try: - data = simplejson.loads(self.content) + data = json.loads(self.content.decode('utf-8')) reason = data['error']['message'] except (ValueError, KeyError): pass @@ -102,6 +103,9 @@ class InvalidChunkSizeError(Error): """The given chunksize is not valid.""" pass +class InvalidNotificationError(Error): + """The channel Notification is invalid.""" + pass class BatchError(HttpError): """Error occured during batch operations.""" diff --git a/apiclient/http.py b/googleapiclient/http.py similarity index 83% rename from apiclient/http.py rename to googleapiclient/http.py index a956477..9ddc6e5 100644 --- a/apiclient/http.py +++ b/googleapiclient/http.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012 Google Inc. +# Copyright 2014 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,35 +18,42 @@ object supporting an execute() method that does the actuall HTTP request. 
""" +from __future__ import absolute_import +import six +from six.moves import range __author__ = 'jcgregorio@google.com (Joe Gregorio)' -import StringIO +from six import BytesIO, StringIO +from six.moves.urllib.parse import urlparse, urlunparse, quote, unquote + import base64 import copy import gzip import httplib2 -import mimeparse +import json +import logging import mimetypes import os +import random import sys -import urllib -import urlparse +import time import uuid from email.generator import Generator from email.mime.multipart import MIMEMultipart from email.mime.nonmultipart import MIMENonMultipart from email.parser import FeedParser -from errors import BatchError -from errors import HttpError -from errors import InvalidChunkSizeError -from errors import ResumableUploadError -from errors import UnexpectedBodyError -from errors import UnexpectedMethodError -from model import JsonModel + +from googleapiclient import mimeparse +from googleapiclient.errors import BatchError +from googleapiclient.errors import HttpError +from googleapiclient.errors import InvalidChunkSizeError +from googleapiclient.errors import ResumableUploadError +from googleapiclient.errors import UnexpectedBodyError +from googleapiclient.errors import UnexpectedMethodError +from googleapiclient.model import JsonModel from oauth2client import util -from oauth2client.anyjson import simplejson DEFAULT_CHUNK_SIZE = 512*1024 @@ -218,7 +225,7 @@ def _to_json(self, strip=None): del d[member] d['_class'] = t.__name__ d['_module'] = t.__module__ - return simplejson.dumps(d) + return json.dumps(d) def to_json(self): """Create a JSON representation of an instance of MediaUpload. @@ -241,7 +248,7 @@ def new_from_json(cls, s): An instance of the subclass of MediaUpload that was serialized with to_json(). """ - data = simplejson.loads(s) + data = json.loads(s) # Find and call the right classmethod from_json() to restore the object. 
module = data['_module'] m = __import__(module, fromlist=module.split('.')[:-1]) @@ -256,7 +263,7 @@ class MediaIoBaseUpload(MediaUpload): Note that the Python file object is compatible with io.Base and can be used with this class also. - fh = io.BytesIO('...Some data to upload...') + fh = BytesIO('...Some data to upload...') media = MediaIoBaseUpload(fh, mimetype='image/png', chunksize=1024*1024, resumable=True) farm.animals().insert( @@ -433,7 +440,7 @@ def to_json(self): @staticmethod def from_json(s): - d = simplejson.loads(s) + d = json.loads(s) return MediaFileUpload(d['_filename'], mimetype=d['_mimetype'], chunksize=d['_chunksize'], resumable=d['_resumable']) @@ -462,7 +469,7 @@ def __init__(self, body, mimetype='application/octet-stream', resumable: bool, True if this is a resumable upload. False means upload in a single request. """ - fd = StringIO.StringIO(body) + fd = BytesIO(body) super(MediaInMemoryUpload, self).__init__(fd, mimetype, chunksize=chunksize, resumable=resumable) @@ -494,7 +501,7 @@ def __init__(self, fd, request, chunksize=DEFAULT_CHUNK_SIZE): Args: fd: io.Base or file object, The stream in which to write the downloaded bytes. - request: apiclient.http.HttpRequest, the media request to perform in + request: googleapiclient.http.HttpRequest, the media request to perform in chunks. chunksize: int, File will be downloaded in chunks of this many bytes. """ @@ -506,16 +513,27 @@ def __init__(self, fd, request, chunksize=DEFAULT_CHUNK_SIZE): self._total_size = None self._done = False - def next_chunk(self): + # Stubs for testing. + self._sleep = time.sleep + self._rand = random.random + + @util.positional(1) + def next_chunk(self, num_retries=0): """Get the next chunk of the download. + Args: + num_retries: Integer, number of times to retry 500's with randomized + exponential backoff. If all retries fail, the raised HttpError + represents the last request. If zero (default), we attempt the + request only once. 
+ Returns: (status, done): (MediaDownloadStatus, boolean) The value of 'done' will be True when the media has been fully downloaded. Raises: - apiclient.errors.HttpError if the response was not a 2xx. + googleapiclient.errors.HttpError if the response was not a 2xx. httplib2.HttpLib2Error if a transport error has occured. """ headers = { @@ -523,13 +541,21 @@ def next_chunk(self): self._progress, self._progress + self._chunksize) } http = self._request.http - http.follow_redirects = False - resp, content = http.request(self._uri, headers=headers) - if resp.status in [301, 302, 303, 307, 308] and 'location' in resp: - self._uri = resp['location'] - resp, content = http.request(self._uri, headers=headers) + for retry_num in range(num_retries + 1): + if retry_num > 0: + self._sleep(self._rand() * 2**retry_num) + logging.warning( + 'Retry #%d for media download: GET %s, following status: %d' + % (retry_num, self._uri, resp.status)) + + resp, content = http.request(self._uri, headers=headers) + if resp.status < 500: + break + if resp.status in [200, 206]: + if 'content-location' in resp and resp['content-location'] != self._uri: + self._uri = resp['content-location'] self._progress += len(content) self._fd.write(content) @@ -537,6 +563,8 @@ def next_chunk(self): content_range = resp['content-range'] length = content_range.rsplit('/', 1)[1] self._total_size = int(length) + elif 'content-length' in resp: + self._total_size = int(resp['content-length']) if self._progress == self._total_size: self._done = True @@ -633,51 +661,72 @@ def __init__(self, http, postproc, uri, # The bytes that have been uploaded. self.resumable_progress = 0 + # Stubs for testing. + self._rand = random.random + self._sleep = time.sleep + @util.positional(1) - def execute(self, http=None): + def execute(self, http=None, num_retries=0): """Execute the request. Args: http: httplib2.Http, an http object to be used in place of the one the HttpRequest request object was constructed with. 
+ num_retries: Integer, number of times to retry 500's with randomized + exponential backoff. If all retries fail, the raised HttpError + represents the last request. If zero (default), we attempt the + request only once. Returns: A deserialized object model of the response body as determined by the postproc. Raises: - apiclient.errors.HttpError if the response was not a 2xx. + googleapiclient.errors.HttpError if the response was not a 2xx. httplib2.HttpLib2Error if a transport error has occured. """ if http is None: http = self.http + if self.resumable: body = None while body is None: - _, body = self.next_chunk(http=http) + _, body = self.next_chunk(http=http, num_retries=num_retries) return body - else: - if 'content-length' not in self.headers: - self.headers['content-length'] = str(self.body_size) - # If the request URI is too long then turn it into a POST request. - if len(self.uri) > MAX_URI_LENGTH and self.method == 'GET': - self.method = 'POST' - self.headers['x-http-method-override'] = 'GET' - self.headers['content-type'] = 'application/x-www-form-urlencoded' - parsed = urlparse.urlparse(self.uri) - self.uri = urlparse.urlunparse( - (parsed.scheme, parsed.netloc, parsed.path, parsed.params, None, - None) - ) - self.body = parsed.query - self.headers['content-length'] = str(len(self.body)) + + # Non-resumable case. + + if 'content-length' not in self.headers: + self.headers['content-length'] = str(self.body_size) + # If the request URI is too long then turn it into a POST request. + if len(self.uri) > MAX_URI_LENGTH and self.method == 'GET': + self.method = 'POST' + self.headers['x-http-method-override'] = 'GET' + self.headers['content-type'] = 'application/x-www-form-urlencoded' + parsed = urlparse(self.uri) + self.uri = urlunparse( + (parsed.scheme, parsed.netloc, parsed.path, parsed.params, None, + None) + ) + self.body = parsed.query + self.headers['content-length'] = str(len(self.body)) + + # Handle retries for server-side errors. 
+ for retry_num in range(num_retries + 1): + if retry_num > 0: + self._sleep(self._rand() * 2**retry_num) + logging.warning('Retry #%d for request: %s %s, following status: %d' + % (retry_num, self.method, self.uri, resp.status)) resp, content = http.request(str(self.uri), method=str(self.method), body=self.body, headers=self.headers) - for callback in self.response_callbacks: - callback(resp) - if resp.status >= 300: - raise HttpError(resp, content, uri=self.uri) + if resp.status < 500: + break + + for callback in self.response_callbacks: + callback(resp) + if resp.status >= 300: + raise HttpError(resp, content, uri=self.uri) return self.postproc(resp, content) @util.positional(2) @@ -693,7 +742,7 @@ def cb(resp): self.response_callbacks.append(cb) @util.positional(1) - def next_chunk(self, http=None): + def next_chunk(self, http=None, num_retries=0): """Execute the next step of a resumable upload. Can only be used if the method being executed supports media uploads and @@ -715,12 +764,20 @@ def next_chunk(self, http=None): print "Upload %d%% complete." % int(status.progress() * 100) + Args: + http: httplib2.Http, an http object to be used in place of the + one the HttpRequest request object was constructed with. + num_retries: Integer, number of times to retry 500's with randomized + exponential backoff. If all retries fail, the raised HttpError + represents the last request. If zero (default), we attempt the + request only once. + Returns: (status, body): (ResumableMediaStatus, object) The body will be None until the resumable media is fully uploaded. Raises: - apiclient.errors.HttpError if the response was not a 2xx. + googleapiclient.errors.HttpError if the response was not a 2xx. httplib2.HttpLib2Error if a transport error has occured. 
""" if http is None: @@ -738,9 +795,19 @@ def next_chunk(self, http=None): start_headers['X-Upload-Content-Length'] = size start_headers['content-length'] = str(self.body_size) - resp, content = http.request(self.uri, self.method, - body=self.body, - headers=start_headers) + for retry_num in range(num_retries + 1): + if retry_num > 0: + self._sleep(self._rand() * 2**retry_num) + logging.warning( + 'Retry #%d for resumable URI request: %s %s, following status: %d' + % (retry_num, self.method, self.uri, resp.status)) + + resp, content = http.request(self.uri, method=self.method, + body=self.body, + headers=start_headers) + if resp.status < 500: + break + if resp.status == 200 and 'location' in resp: self.resumable_uri = resp['location'] else: @@ -792,13 +859,23 @@ def next_chunk(self, http=None): # calculate the size when working with _StreamSlice. 'Content-Length': str(chunk_end - self.resumable_progress + 1) } - try: - resp, content = http.request(self.resumable_uri, 'PUT', - body=data, - headers=headers) - except: - self._in_error_state = True - raise + + for retry_num in range(num_retries + 1): + if retry_num > 0: + self._sleep(self._rand() * 2**retry_num) + logging.warning( + 'Retry #%d for media upload: %s %s, following status: %d' + % (retry_num, self.method, self.uri, resp.status)) + + try: + resp, content = http.request(self.resumable_uri, method='PUT', + body=data, + headers=headers) + except: + self._in_error_state = True + raise + if resp.status < 500: + break return self._process_response(resp, content) @@ -814,7 +891,7 @@ def _process_response(self, resp, content): The body will be None until the resumable media is fully uploaded. Raises: - apiclient.errors.HttpError if the response was not a 2xx or a 308. + googleapiclient.errors.HttpError if the response was not a 2xx or a 308. 
""" if resp.status in [200, 201]: self._in_error_state = False @@ -839,13 +916,15 @@ def to_json(self): d['resumable'] = self.resumable.to_json() del d['http'] del d['postproc'] + del d['_sleep'] + del d['_rand'] - return simplejson.dumps(d) + return json.dumps(d) @staticmethod def from_json(s, http, postproc): """Returns an HttpRequest populated with info from a JSON object.""" - d = simplejson.loads(s) + d = json.loads(s) if d['resumable'] is not None: d['resumable'] = MediaUpload.new_from_json(d['resumable']) return HttpRequest( @@ -863,7 +942,7 @@ class BatchHttpRequest(object): """Batches multiple HttpRequest objects into a single HTTP request. Example: - from apiclient.http import BatchHttpRequest + from googleapiclient.http import BatchHttpRequest def list_animals(request_id, response, exception): \"\"\"Do something with the animals list response.\"\"\" @@ -900,7 +979,7 @@ def __init__(self, callback=None, batch_uri=None): callback: callable, A callback to be called for each response, of the form callback(id, response, exception). The first parameter is the request id, and the second is the deserialized response object. The - third is an apiclient.errors.HttpError exception object if an HTTP error + third is an googleapiclient.errors.HttpError exception object if an HTTP error occurred while processing the request, or None if no error occurred. batch_uri: string, URI to send batch requests to. """ @@ -973,7 +1052,7 @@ def _id_to_header(self, id_): if self._base_id is None: self._base_id = uuid.uuid4() - return '<%s+%s>' % (self._base_id, urllib.quote(id_)) + return '<%s+%s>' % (self._base_id, quote(id_)) def _header_to_id(self, header): """Convert a Content-ID header value to an id. 
@@ -996,7 +1075,7 @@ def _header_to_id(self, header): raise BatchError("Invalid value for Content-ID: %s" % header) base, id_ = header[1:-1].rsplit('+', 1) - return urllib.unquote(id_) + return unquote(id_) def _serialize_request(self, request): """Convert an HttpRequest object into a string. @@ -1008,9 +1087,9 @@ def _serialize_request(self, request): The request as a string in application/http format. """ # Construct status line - parsed = urlparse.urlparse(request.uri) - request_line = urlparse.urlunparse( - (None, None, parsed.path, parsed.params, parsed.query, None) + parsed = urlparse(request.uri) + request_line = urlunparse( + ('', '', parsed.path, parsed.params, parsed.query, '') ) status_line = request.method + ' ' + request_line + ' HTTP/1.1\n' major, minor = request.headers.get('content-type', 'application/json').split('/') @@ -1025,7 +1104,7 @@ def _serialize_request(self, request): if 'content-type' in headers: del headers['content-type'] - for key, value in headers.iteritems(): + for key, value in six.iteritems(headers): msg[key] = value msg['Host'] = parsed.netloc msg.set_unixfrom(None) @@ -1035,17 +1114,13 @@ def _serialize_request(self, request): msg['content-length'] = str(len(request.body)) # Serialize the mime message. - fp = StringIO.StringIO() + fp = StringIO() # maxheaderlen=0 means don't line wrap headers. g = Generator(fp, maxheaderlen=0) g.flatten(msg, unixfrom=False) body = fp.getvalue() - # Strip off the \n\n that the MIME lib tacks onto the end of the payload. - if request.body is None: - body = body[:-2] - - return status_line.encode('utf-8') + body + return status_line + body def _deserialize_response(self, payload): """Convert string into httplib2 response and content. @@ -1105,7 +1180,7 @@ def add(self, request, callback=None, request_id=None): callback: callable, A callback to be called for this response, of the form callback(id, response, exception). 
The first parameter is the request id, and the second is the deserialized response object. The - third is an apiclient.errors.HttpError exception object if an HTTP error + third is an googleapiclient.errors.HttpError exception object if an HTTP error occurred while processing the request, or None if no errors occurred. request_id: string, A unique id for the request. The id will be passed to the callback with the response. @@ -1138,7 +1213,7 @@ def _execute(self, http, order, requests): Raises: httplib2.HttpLib2Error if a transport error has occured. - apiclient.errors.BatchError if the response is the wrong format. + googleapiclient.errors.BatchError if the response is the wrong format. """ message = MIMEMultipart('mixed') # Message should not write out it's own headers. @@ -1156,23 +1231,29 @@ def _execute(self, http, order, requests): msg.set_payload(body) message.attach(msg) - body = message.as_string() + # encode the body: note that we can't use `as_string`, because + # it plays games with `From ` lines. + fp = StringIO() + g = Generator(fp, mangle_from_=False) + g.flatten(message, unixfrom=False) + body = fp.getvalue() headers = {} headers['content-type'] = ('multipart/mixed; ' 'boundary="%s"') % message.get_boundary() - resp, content = http.request(self._batch_uri, 'POST', body=body, + resp, content = http.request(self._batch_uri, method='POST', body=body, headers=headers) if resp.status >= 300: raise HttpError(resp, content, uri=self._batch_uri) - # Now break out the individual responses and store each one. - boundary, _ = content.split(None, 1) - # Prepend with a content-type header so FeedParser can handle it. header = 'content-type: %s\r\n\r\n' % resp['content-type'] + # PY3's FeedParser only accepts unicode. So we should decode content + # here, and encode each payload again. 
+ if six.PY3: + content = content.decode('utf-8') for_parser = header + content parser = FeedParser() @@ -1186,6 +1267,9 @@ def _execute(self, http, order, requests): for part in mime_response.get_payload(): request_id = self._header_to_id(part['Content-ID']) response, content = self._deserialize_response(part.get_payload()) + # We encode content here to emulate normal http response. + if isinstance(content, six.text_type): + content = content.encode('utf-8') self._responses[request_id] = (response, content) @util.positional(1) @@ -1202,7 +1286,7 @@ def execute(self, http=None): Raises: httplib2.HttpLib2Error if a transport error has occured. - apiclient.errors.BatchError if the response is the wrong format. + googleapiclient.errors.BatchError if the response is the wrong format. """ # If http is not supplied use the first valid one given in the requests. @@ -1216,23 +1300,60 @@ def execute(self, http=None): if http is None: raise ValueError("Missing a valid http object.") - self._execute(http, self._order, self._requests) + #self._execute(http, self._order, self._requests) # Loop over all the requests and check for 401s. For each 401 request the # credentials should be refreshed and then sent again in a separate batch. 
- redo_requests = {} - redo_order = [] - - for request_id in self._order: - resp, content = self._responses[request_id] - if resp['status'] == '401': - redo_order.append(request_id) - request = self._requests[request_id] - self._refresh_and_apply_credentials(request, http) - redo_requests[request_id] = request - - if redo_requests: - self._execute(http, redo_order, redo_requests) + #redo_requests = {} + #redo_order = [] + + #for request_id in self._order: + # resp, content = self._responses[request_id] + # if resp['status'] == '401': + # redo_order.append(request_id) + # request = self._requests[request_id] + # self._refresh_and_apply_credentials(request, http) + # redo_requests[request_id] = request + + #if redo_requests: + # self._execute(http, redo_order, redo_requests) + + requests = self._requests + order = self._order + n = 0 + while requests and n <= 10: + wait_on_fail = (2 ** n) if (2 ** n) < 60 else 60 + randomness = float(random.randint(1,1000)) / 1000 + wait_on_fail = wait_on_fail + randomness + if n > 3: + sys.stderr.write('\n retrying %s requests after backing off %s seconds...\n' + % (len(requests), int(wait_on_fail))) + if n > 0: + time.sleep(wait_on_fail) + self._execute(http, order, requests) + n += 1 + requests = {} + order = [] + for request_id in self._order: + resp, content = self._responses[request_id] + if resp['status'] == '401': + order.append(request_id) + request = self._requests[request_id] + self._refresh_and_apply_credentials(request, http) + requests[request_id] = request + try: + error = json.loads(content.decode("utf-8")) + except ValueError: + continue + try: + reason = error['error']['errors'][0]['reason'] + except KeyError: + continue + if reason in ['limitExceeded', 'servingLimitExceeded', + 'rateLimitExceeded', 'userRateLimitExceeded', + 'backendError', 'internalError']: + order.append(request_id) + requests[request_id] = self._requests[request_id] # Now process all callbacks that are erroring, and raise an exception for # ones 
that return a non-2xx response? Or add extra parameter to callback @@ -1250,7 +1371,7 @@ def execute(self, http=None): if resp.status >= 300: raise HttpError(resp, content, uri=request.uri) response = request.postproc(resp, content) - except HttpError, e: + except HttpError as e: exception = e if callback is not None: @@ -1308,7 +1429,7 @@ class RequestMockBuilder(object): 'plus.activities.get': (None, response), } ) - apiclient.discovery.build("plus", "v1", requestBuilder=requestBuilder) + googleapiclient.discovery.build("plus", "v1", requestBuilder=requestBuilder) Methods that you do not supply a response for will return a 200 OK with an empty string as the response content or raise an excpetion @@ -1352,8 +1473,8 @@ def __call__(self, http, postproc, uri, method='GET', body=None, # or expecting a body and not provided one. raise UnexpectedBodyError(expected_body, body) if isinstance(expected_body, str): - expected_body = simplejson.loads(expected_body) - body = simplejson.loads(body) + expected_body = json.loads(expected_body) + body = json.loads(body) if body != expected_body: raise UnexpectedBodyError(expected_body, body) return HttpRequestMock(resp, content, postproc) @@ -1374,9 +1495,9 @@ def __init__(self, filename=None, headers=None): headers: dict, header to return with response """ if headers is None: - headers = {'status': '200 OK'} + headers = {'status': '200'} if filename: - f = file(filename, 'r') + f = open(filename, 'r') self.data = f.read() f.close() else: @@ -1444,7 +1565,7 @@ def request(self, uri, if content == 'echo_request_headers': content = headers elif content == 'echo_request_headers_as_json': - content = simplejson.dumps(headers) + content = json.dumps(headers) elif content == 'echo_request_body': if hasattr(body, 'read'): content = body.read() @@ -1452,6 +1573,8 @@ def request(self, uri, content = body elif content == 'echo_request_uri': content = uri + if isinstance(content, six.text_type): + content = content.encode('utf-8') return 
httplib2.Response(resp), content diff --git a/apiclient/mimeparse.py b/googleapiclient/mimeparse.py similarity index 95% rename from apiclient/mimeparse.py rename to googleapiclient/mimeparse.py index cbb9d07..bc9ad09 100644 --- a/apiclient/mimeparse.py +++ b/googleapiclient/mimeparse.py @@ -1,4 +1,4 @@ -# Copyright (C) 2007 Joe Gregorio +# Copyright 2014 Joe Gregorio # # Licensed under the MIT License @@ -21,6 +21,9 @@ - best_match(): Choose the mime-type with the highest quality ('q') from a list of candidates. """ +from __future__ import absolute_import +from functools import reduce +import six __version__ = '0.1.3' __author__ = 'Joe Gregorio' @@ -68,7 +71,7 @@ def parse_media_range(range): necessary. """ (type, subtype, params) = parse_mime_type(range) - if not params.has_key('q') or not params['q'] or \ + if 'q' not in params or not params['q'] or \ not float(params['q']) or float(params['q']) > 1\ or float(params['q']) < 0: params['q'] = '1' @@ -98,8 +101,8 @@ def fitness_and_quality_parsed(mime_type, parsed_ranges): target_subtype == '*') if type_match and subtype_match: param_matches = reduce(lambda x, y: x + y, [1 for (key, value) in \ - target_params.iteritems() if key != 'q' and \ - params.has_key(key) and value == params[key]], 0) + six.iteritems(target_params) if key != 'q' and \ + key in params and value == params[key]], 0) fitness = (type == target_type) and 100 or 0 fitness += (subtype == target_subtype) and 10 or 0 fitness += param_matches diff --git a/apiclient/model.py b/googleapiclient/model.py similarity index 90% rename from apiclient/model.py rename to googleapiclient/model.py index 12fcab6..e8afb63 100644 --- a/apiclient/model.py +++ b/googleapiclient/model.py @@ -1,6 +1,4 @@ -#!/usr/bin/python2.4 -# -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,21 +19,21 @@ for converting between the wire format and the Python object representation. """ +from __future__ import absolute_import +import six __author__ = 'jcgregorio@google.com (Joe Gregorio)' -import gflags +import json import logging -import urllib -from errors import HttpError -from oauth2client.anyjson import simplejson +from six.moves.urllib.parse import urlencode + +from googleapiclient import __version__ +from googleapiclient.errors import HttpError -FLAGS = gflags.FLAGS -gflags.DEFINE_boolean('dump_request_response', False, - 'Dump all http server requests and responses. ' - ) +dump_request_response = False def _abstract(): @@ -80,7 +78,7 @@ def response(self, resp, content): The body de-serialized as a Python object. Raises: - apiclient.errors.HttpError if a non 2xx response is received. + googleapiclient.errors.HttpError if a non 2xx response is received. """ _abstract() @@ -106,14 +104,14 @@ class BaseModel(Model): def _log_request(self, headers, path_params, query, body): """Logs debugging information about the request if requested.""" - if FLAGS.dump_request_response: + if dump_request_response: logging.info('--request-start--') logging.info('-headers-start-') - for h, v in headers.iteritems(): + for h, v in six.iteritems(headers): logging.info('%s: %s', h, v) logging.info('-headers-end-') logging.info('-path-parameters-start-') - for h, v in path_params.iteritems(): + for h, v in six.iteritems(path_params): logging.info('%s: %s', h, v) logging.info('-path-parameters-end-') logging.info('body: %s', body) @@ -128,7 +126,7 @@ def request(self, headers, path_params, query_params, body_value): path_params: dict, parameters that appear in the request path query_params: dict, parameters that appear in the query body_value: object, the request body as a Python object, which must be - serializable by simplejson. + serializable by json. 
Returns: A tuple of (headers, path_params, query, body) @@ -144,7 +142,7 @@ def request(self, headers, path_params, query_params, body_value): headers['user-agent'] += ' ' else: headers['user-agent'] = '' - headers['user-agent'] += 'google-api-python-client/1.0' + headers['user-agent'] += 'google-api-python-client/%s (gzip)' % __version__ if body_value is not None: headers['content-type'] = self.content_type @@ -164,22 +162,22 @@ def _build_query(self, params): if self.alt_param is not None: params.update({'alt': self.alt_param}) astuples = [] - for key, value in params.iteritems(): + for key, value in six.iteritems(params): if type(value) == type([]): for x in value: x = x.encode('utf-8') astuples.append((key, x)) else: - if getattr(value, 'encode', False) and callable(value.encode): + if isinstance(value, six.text_type) and callable(value.encode): value = value.encode('utf-8') astuples.append((key, value)) - return '?' + urllib.urlencode(astuples) + return '?' + urlencode(astuples) def _log_response(self, resp, content): """Logs debugging information about the response if requested.""" - if FLAGS.dump_request_response: + if dump_request_response: logging.info('--response-start--') - for h, v in resp.iteritems(): + for h, v in six.iteritems(resp): logging.info('%s: %s', h, v) if content: logging.info(content) @@ -196,7 +194,7 @@ def response(self, resp, content): The body de-serialized as a Python object. Raises: - apiclient.errors.HttpError if a non 2xx response is received. + googleapiclient.errors.HttpError if a non 2xx response is received. 
""" self._log_response(resp, content) # Error handling is TBD, for example, do we retry @@ -257,10 +255,14 @@ def serialize(self, body_value): if (isinstance(body_value, dict) and 'data' not in body_value and self._data_wrapper): body_value = {'data': body_value} - return simplejson.dumps(body_value) + return json.dumps(body_value) def deserialize(self, content): - body = simplejson.loads(content) + try: + content = content.decode('utf-8') + except AttributeError: + pass + body = json.loads(content) if self._data_wrapper and isinstance(body, dict) and 'data' in body: body = body['data'] return body @@ -363,7 +365,7 @@ def makepatch(original, modified): body=makepatch(original, item)).execute() """ patch = {} - for key, original_value in original.iteritems(): + for key, original_value in six.iteritems(original): modified_value = modified.get(key, None) if modified_value is None: # Use None to signal that the element is deleted diff --git a/googleapiclient/sample_tools.py b/googleapiclient/sample_tools.py new file mode 100644 index 0000000..2b4e7b4 --- /dev/null +++ b/googleapiclient/sample_tools.py @@ -0,0 +1,103 @@ +# Copyright 2014 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for making samples. + +Consolidates a lot of code commonly repeated in sample applications. 
+""" +from __future__ import absolute_import + +__author__ = 'jcgregorio@google.com (Joe Gregorio)' +__all__ = ['init'] + + +import argparse +import httplib2 +import os + +from googleapiclient import discovery +from oauth2client import client +from oauth2client import file +from oauth2client import tools + + +def init(argv, name, version, doc, filename, scope=None, parents=[], discovery_filename=None): + """A common initialization routine for samples. + + Many of the sample applications do the same initialization, which has now + been consolidated into this function. This function uses common idioms found + in almost all the samples, i.e. for an API with name 'apiname', the + credentials are stored in a file named apiname.dat, and the + client_secrets.json file is stored in the same directory as the application + main file. + + Args: + argv: list of string, the command-line parameters of the application. + name: string, name of the API. + version: string, version of the API. + doc: string, description of the application. Usually set to __doc__. + file: string, filename of the application. Usually set to __file__. + parents: list of argparse.ArgumentParser, additional command-line flags. + scope: string, The OAuth scope used. + discovery_filename: string, name of local discovery file (JSON). Use when discovery doc not available via URL. + + Returns: + A tuple of (service, flags), where service is the service object and flags + is the parsed command-line flags. + """ + if scope is None: + scope = 'https://www.googleapis.com/auth/' + name + + # Parser command-line arguments. 
+ parent_parsers = [tools.argparser] + parent_parsers.extend(parents) + parser = argparse.ArgumentParser( + description=doc, + formatter_class=argparse.RawDescriptionHelpFormatter, + parents=parent_parsers) + flags = parser.parse_args(argv[1:]) + + # Name of a file containing the OAuth 2.0 information for this + # application, including client_id and client_secret, which are found + # on the API Access tab on the Google APIs + # Console . + client_secrets = os.path.join(os.path.dirname(filename), + 'client_secrets.json') + + # Set up a Flow object to be used if we need to authenticate. + flow = client.flow_from_clientsecrets(client_secrets, + scope=scope, + message=tools.message_if_missing(client_secrets)) + + # Prepare credentials, and authorize HTTP object with them. + # If the credentials don't exist or are invalid run through the native client + # flow. The Storage object will ensure that if successful the good + # credentials will get written back to a file. + storage = file.Storage(name + '.dat') + credentials = storage.get() + if credentials is None or credentials.invalid: + credentials = tools.run_flow(flow, storage, flags) + http = credentials.authorize(http = httplib2.Http()) + + if discovery_filename is None: + # Construct a service object via the discovery service. + service = discovery.build(name, version, http=http) + else: + # Construct a service object using a local discovery document file. + with open(discovery_filename) as discovery_file: + service = discovery.build_from_document( + discovery_file.read(), + base='https://www.googleapis.com/', + http=http) + return (service, flags) diff --git a/apiclient/schema.py b/googleapiclient/schema.py similarity index 98% rename from apiclient/schema.py rename to googleapiclient/schema.py index d076a86..ecb3f8b 100644 --- a/apiclient/schema.py +++ b/googleapiclient/schema.py @@ -1,4 +1,4 @@ -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -56,6 +56,8 @@ The constructor takes a discovery document in which to look up named schema. """ +from __future__ import absolute_import +import six # TODO(jcgregorio) support format, enum, minimum, maximum @@ -64,7 +66,6 @@ import copy from oauth2client import util -from oauth2client.anyjson import simplejson class Schemas(object): @@ -250,7 +251,7 @@ def _to_str_impl(self, schema): self.emitEnd('{', schema.get('description', '')) self.indent() if 'properties' in schema: - for pname, pschema in schema.get('properties', {}).iteritems(): + for pname, pschema in six.iteritems(schema.get('properties', {})): self.emitBegin('"%s": ' % pname) self._to_str_impl(pschema) elif 'additionalProperties' in schema: diff --git a/gyb.py b/gyb.py index 7e21376..29c4336 100644 --- a/gyb.py +++ b/gyb.py @@ -1,1341 +1,1499 @@ -#!/usr/bin/env python- -# -# Got Your Back -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -u"""\n%s\n\nGot Your Back (GYB) is a command line tool which allows users to backup and restore their Gmail. 
- -For more information, see http://code.google.com/p/got-your-back/ -""" - -global __name__, __author__, __email__, __version__, __license__ -__program_name__ = u'Got Your Back: Gmail Backup' -__author__ = u'Jay Lee' -__email__ = u'jay0lee@gmail.com' -__version__ = u'0.31' -__license__ = u'Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)' -__db_schema_version__ = u'5' -__db_schema_min_version__ = u'2' #Minimum for restore - -import imaplib -from optparse import OptionParser, SUPPRESS_HELP -import sys -import os -import os.path -import time -import random -import struct -import platform -import StringIO -import socket -import datetime -import sqlite3 -import email -import mailbox -import mimetypes -import re -import shlex -from itertools import islice, chain -import math - -try: - import json as simplejson -except ImportError: - import simplejson - -import httplib2 -import oauth2client.client -import oauth2client.file -import oauth2client.tools -import gflags -import apiclient -import apiclient.discovery -import apiclient.errors -import gimaplib - -def SetupOptionParser(): - # Usage message is the module's docstring. - parser = OptionParser(usage=__doc__ % getGYBVersion(), add_help_option=False) - parser.add_option('--email', - dest='email', - help='Full email address of user or group to act against') - action_choices = ['backup','restore', 'restore-group', 'restore-mbox', 'count', 'purge', 'purge-labels', 'estimate', 'quota', 'reindex'] - parser.add_option('--action', - type='choice', - choices=action_choices, - dest='action', - default='backup', - help='Action to perform - %s. Default is backup.' % ', '.join(action_choices)) - parser.add_option('--search', - dest='gmail_search', - default='in:anywhere', - help='Optional: On backup, estimate, count and purge, Gmail search to scope operation against') - parser.add_option('--local-folder', - dest='local_folder', - help='Optional: On backup, restore, estimate, local folder to use. 
Default is GYB-GMail-Backup-', - default='XXXuse-email-addressXXX') - parser.add_option('--use-imap-folder', - dest='use_folder', - help='Optional: IMAP folder to act against. Default is "All Mail" label. You can run "--use_folder [Gmail]/Chats" to backup chat.') - parser.add_option('--label-restored', - dest='label_restored', - help='Optional: On restore, all messages will additionally receive this label. For example, "--label_restored gyb-restored" will label all uploaded messages with a gyb-restored label.') - parser.add_option('--strip-labels', - dest='strip_labels', - action='store_true', - default=False, - help='Optional: On restore and restore-mbox, strip existing labels from messages except for those explicitly declared with the --label-restored parameter.') - parser.add_option('--service-account', - dest='service_account', - help='Google Apps Business and Education only. Use OAuth 2.0 Service Account to authenticate.') - parser.add_option('--use-admin', - dest='use_admin', - help='Optional. On restore-group, authenticate as this admin user.') - parser.add_option('--batch-size', - dest='batch_size', - type='int', - default=100, - help='Optional: On backup, sets the number of messages to batch download.') - parser.add_option('--noresume', - action='store_true', - default=False, - help='Optional: On restores, start from beginning. 
Default is to resume where last restore left off.') - parser.add_option('--fast-incremental', - dest='refresh', - action='store_false', - default=True, - help='Optional: On backup, skips refreshing labels for existing message') - parser.add_option('--debug', - action='store_true', - dest='debug', - help='Turn on verbose debugging and connection information (troubleshooting)') - parser.add_option('--version', - action='store_true', - dest='version', - help='print GYB version and quit') - parser.add_option('--help', - action='help', - help='Display this message.') - return parser - -def win32_unicode_argv(): - from ctypes import POINTER, byref, cdll, c_int, windll - from ctypes.wintypes import LPCWSTR, LPWSTR - - GetCommandLineW = cdll.kernel32.GetCommandLineW - GetCommandLineW.argtypes = [] - GetCommandLineW.restype = LPCWSTR - - CommandLineToArgvW = windll.shell32.CommandLineToArgvW - CommandLineToArgvW.argtypes = [LPCWSTR, POINTER(c_int)] - CommandLineToArgvW.restype = POINTER(LPWSTR) - - cmd = GetCommandLineW() - argc = c_int(0) - argv = CommandLineToArgvW(cmd, byref(argc)) - if argc.value > 0: - # Remove Python executable and commands if present - start = argc.value - len(sys.argv) - return [argv[i] for i in xrange(start, argc.value)] - -def getProgPath(): - if os.path.abspath('/') != -1: - divider = '/' - else: - divider = '\\' - return os.path.dirname(os.path.realpath(sys.argv[0]))+divider - -def batch(iterable, size): - sourceiter = iter(iterable) - while True: - batchiter = islice(sourceiter, size) - yield chain([batchiter.next()], batchiter) - -def getOAuthFromConfigFile(email): - cfgFile = '%s%s.cfg' % (getProgPath(), email) - if os.path.isfile(cfgFile): - f = open(cfgFile, 'r') - key = f.readline()[0:-1] - secret = f.readline() - f.close() - return (key, secret) - else: - return (False, False) - -def requestOAuthAccess(email, debug=False): - scopes = ['https://mail.google.com/', # IMAP/SMTP client access - 'https://www.googleapis.com/auth/userinfo#email', 
- 'https://www.googleapis.com/auth/apps.groups.migration'] - CLIENT_SECRETS = getProgPath()+'client_secrets.json' - MISSING_CLIENT_SECRETS_MESSAGE = """ -WARNING: Please configure OAuth 2.0 - -To make GYB run you will need to populate the client_secrets.json file -found at: - - %s - -with information from the APIs Console . - -""" % (CLIENT_SECRETS) - FLOW = oauth2client.client.flow_from_clientsecrets(CLIENT_SECRETS, scope=scopes, message=MISSING_CLIENT_SECRETS_MESSAGE, login_hint=email) - cfgFile = '%s%s.cfg' % (getProgPath(), email) - storage = oauth2client.file.Storage(cfgFile) - credentials = storage.get() - if os.path.isfile(getProgPath()+'nobrowser.txt'): - gflags.FLAGS.auth_local_webserver = False - if credentials is None or credentials.invalid: - certFile = getProgPath()+'cacert.pem' - disable_ssl_certificate_validation = False - if os.path.isfile(getProgPath()+'noverifyssl.txt'): - disable_ssl_certificate_validation = True - http = httplib2.Http(ca_certs=certFile, disable_ssl_certificate_validation=disable_ssl_certificate_validation) - credentials = oauth2client.tools.run(FLOW, storage, short_url=True, http=http) - -def doGYBCheckForUpdates(): - import urllib2, calendar - no_update_check_file = getProgPath()+'noupdatecheck.txt' - last_update_check_file = getProgPath()+'lastcheck.txt' - if os.path.isfile(no_update_check_file): return - try: - current_version = float(__version__) - except ValueError: - return - if os.path.isfile(last_update_check_file): - f = open(last_update_check_file, 'r') - last_check_time = int(f.readline()) - f.close() - else: - last_check_time = 0 - now_time = calendar.timegm(time.gmtime()) - one_week_ago_time = now_time - 604800 - if last_check_time > one_week_ago_time: return - try: - c = urllib2.urlopen(u'https://gyb-update.appspot.com/latest-version.txt?v=%s' % __version__) - try: - latest_version = float(c.read()) - except ValueError: - return - if latest_version <= current_version: - f = open(last_update_check_file, 'w') - 
f.write(str(now_time)) - f.close() - return - a = urllib2.urlopen(u'https://gyb-update.appspot.com/latest-version-announcement.txt?v=%s') - announcement = a.read() - sys.stderr.write('\nThere\'s a new version of GYB!!!\n\n') - sys.stderr.write(announcement) - visit_gyb = raw_input(u"\n\nHit Y to visit the GYB website and download the latest release. Hit Enter to just continue with this boring old version. GYB won't bother you with this announcemnt for 1 week or you can create a file named %s and GYB won't ever check for updates: " % no_update_check_file) - if visit_gyb.lower() == u'y': - import webbrowser - webbrowser.open(u'http://git.io/gyb') - print u'GYB is now exiting so that you can overwrite this old version with the latest release' - sys.exit(0) - f = open(last_update_check_file, 'w') - f.write(str(now_time)) - f.close() - except urllib2.HTTPError: - return - except urllib2.URLError: - return - -def generateXOAuthString(email, service_account=False, debug=False): - if debug: - httplib2.debuglevel = 4 - if service_account: - try: - f = file(getProgPath()+'privatekey.p12', 'rb') - key = f.read() - f.close() - service_account_name = service_account - except IOError: - json_string = file(getProgPath()+'privatekey.json', 'rb').read() - json_data = simplejson.loads(json_string) - key = json_data['private_key'] - service_account_name = json_data['client_email'] - scope = 'https://mail.google.com/' - credentials = oauth2client.client.SignedJwtAssertionCredentials(service_account_name=service_account_name, private_key=key, scope=scope, user_agent=getGYBVersion(' / '), prn=email) - disable_ssl_certificate_validation = False - if os.path.isfile(getProgPath()+'noverifyssl.txt'): - disable_ssl_certificate_validation = True - http = httplib2.Http(ca_certs=getProgPath()+'cacert.pem', disable_ssl_certificate_validation=disable_ssl_certificate_validation) - if debug: - httplib2.debuglevel = 4 - http = credentials.authorize(http) - service = 
apiclient.discovery.build('oauth2', 'v2', http=http) - else: - cfgFile = '%s%s.cfg' % (getProgPath(), email) - storage = oauth2client.file.Storage(cfgFile) - credentials = storage.get() - if credentials is None or credentials.invalid: - requestOAuthAccess(email, debug) - credentials = storage.get() - if credentials.access_token_expired: - disable_ssl_certificate_validation = False - if os.path.isfile(getProgPath()+'noverifyssl.txt'): - disable_ssl_certificate_validation = True - credentials.refresh(httplib2.Http(ca_certs=getProgPath()+'cacert.pem', disable_ssl_certificate_validation=disable_ssl_certificate_validation)) - return "user=%s\001auth=OAuth %s\001\001" % (email, credentials.access_token) - -def just_quote(self, arg): - return '"%s"' % arg - -def callGAPI(service, function, soft_errors=False, throw_reasons=[], **kwargs): - method = getattr(service, function) - retries = 3 - for n in range(1, retries+1): - try: - return method(**kwargs).execute() - except apiclient.errors.HttpError, e: - error = simplejson.loads(e.content) - try: - reason = error['error']['errors'][0]['reason'] - http_status = error['error']['code'] - message = error['error']['errors'][0]['message'] - if reason in throw_reasons: - raise - if n != retries and reason in ['rateLimitExceeded', 'userRateLimitExceeded', 'backendError']: - wait_on_fail = (2 ** n) if (2 ** n) < 60 else 60 - randomness = float(random.randint(1,1000)) / 1000 - wait_on_fail = wait_on_fail + randomness - if n > 3: sys.stderr.write('\nTemp error %s. Backing off %s seconds...' 
% (reason, int(wait_on_fail))) - time.sleep(wait_on_fail) - if n > 3: sys.stderr.write('attempt %s/%s\n' % (n+1, retries)) - continue - sys.stderr.write('\n%s: %s - %s\n' % (http_status, message, reason)) - if soft_errors: - sys.stderr.write(' - Giving up.\n') - return - else: - sys.exit(int(http_status)) - except KeyError: - sys.stderr.write('Unknown Error: %s' % e) - sys.exit(1) - except oauth2client.client.AccessTokenRefreshError, e: - sys.stderr.write('Error: Authentication Token Error - %s' % e) - sys.exit(403) - -def message_is_backed_up(message_num, sqlcur, sqlconn, backup_folder): - try: - sqlcur.execute(''' - SELECT message_filename FROM uids NATURAL JOIN messages - where uid = ?''', ((message_num),)) - except sqlite3.OperationalError, e: - if e.message == 'no such table: messages': - print "\n\nError: your backup database file appears to be corrupted." - else: - print "SQL error:%s" % e - sys.exit(8) - sqlresults = sqlcur.fetchall() - for x in sqlresults: - filename = x[0] - if os.path.isfile(os.path.join(backup_folder, filename)): - return True - return False - -def get_db_settings(sqlcur): - try: - sqlcur.execute('SELECT name, value FROM settings') - db_settings = dict(sqlcur) - return db_settings - except sqlite3.OperationalError, e: - if e.message == 'no such table: settings': - print "\n\nSorry, this version of GYB requires version %s of the database schema. Your backup folder database does not have a version." 
% (__db_schema_version__) - sys.exit(6) - else: - print "%s" % e - -def check_db_settings(db_settings, action, user_email_address): - if (db_settings['db_version'] < __db_schema_min_version__ or - db_settings['db_version'] > __db_schema_version__): - print "\n\nSorry, this backup folder was created with version %s of the database schema while GYB %s requires version %s - %s for restores" % (db_settings['db_version'], __version__, __db_schema_min_version__, __db_schema_version__) - sys.exit(4) - - # Only restores are allowed to use a backup folder started with another account (can't allow 2 Google Accounts to backup/estimate from same folder) - if action not in ['restore', 'restore-group', 'restore-mbox']: - if user_email_address.lower() != db_settings['email_address'].lower(): - print "\n\nSorry, this backup folder should only be used with the %s account that it was created with for incremental backups. You specified the %s account" % (db_settings['email_address'], user_email_address) - sys.exit(5) - -def convertDB(sqlconn, uidvalidity, oldversion): - print "Converting database" - try: - with sqlconn: - if oldversion < '3': - # Convert to schema 3 - sqlconn.executescript(''' - BEGIN; - CREATE TABLE uids - (message_num INTEGER, uid INTEGER PRIMARY KEY); - INSERT INTO uids (uid, message_num) - SELECT message_num as uid, message_num FROM messages; - CREATE INDEX labelidx ON labels (message_num); - CREATE INDEX flagidx ON flags (message_num); - ''') - if oldversion < '4': - # Convert to schema 4 - sqlconn.execute(''' - ALTER TABLE messages ADD COLUMN rfc822_msgid TEXT; - ''') - if oldversion < '5': - # Convert to schema 5 - sqlconn.executescript(''' - DROP INDEX labelidx; - DROP INDEX flagidx; - CREATE UNIQUE INDEX labelidx ON labels (message_num, label); - CREATE UNIQUE INDEX flagidx ON flags (message_num, flag); - ''') - sqlconn.executemany('REPLACE INTO settings (name, value) VALUES (?,?)', - (('uidvalidity',uidvalidity), - ('db_version', __db_schema_version__)) ) - 
sqlconn.commit() - except sqlite3.OperationalError, e: - print "Conversion error: %s" % e.message - - print "GYB database converted to version %s" % __db_schema_version__ - -def getMessageIDs (sqlconn, backup_folder): - sqlcur = sqlconn.cursor() - header_parser = email.parser.HeaderParser() - for message_num, filename in sqlconn.execute(''' - SELECT message_num, message_filename FROM messages - WHERE rfc822_msgid IS NULL'''): - message_full_filename = os.path.join(backup_folder, filename) - if os.path.isfile(message_full_filename): - f = open(message_full_filename, 'rb') - msgid = header_parser.parse(f, True).get('message-id') or '' - f.close() - sqlcur.execute( - 'UPDATE messages SET rfc822_msgid = ? WHERE message_num = ?', - (msgid, message_num)) - sqlconn.commit() - -def rebuildUIDTable(imapconn, sqlconn): - sqlcur = sqlconn.cursor() - header_parser = email.parser.HeaderParser() - sqlcur.execute('DELETE FROM uids') - # Create an index on the Message ID to speed up the process - sqlcur.execute('CREATE INDEX IF NOT EXISTS msgidx on messages(rfc822_msgid)') - exists = imapconn.response('exists') - exists = int(exists[1][0]) - batch_size = 1000 - for batch_start in xrange(1, exists, batch_size): - batch_end = min(exists, batch_start+batch_size-1) - t, d = imapconn.fetch('%d:%d' % (batch_start, batch_end), - '(UID INTERNALDATE BODY.PEEK[HEADER.FIELDS ' - '(FROM TO SUBJECT MESSAGE-ID)])') - if t != 'OK': - print "\nError: failed to retrieve messages." 
- print "%s %s" % (t, d) - sys.exit(5) - for extras, header in (x for x in d if x != ')'): - uid, message_date = re.search('UID ([0-9]*) (INTERNALDATE \".*\")', - extras).groups() - try: - time_seconds = time.mktime(imaplib.Internaldate2tuple(message_date)) - except OverflowError: - time_seconds = time.time() - message_internaldate = datetime.datetime.fromtimestamp(time_seconds) - m = header_parser.parsestr(header, True) - msgid = m.get('message-id') or '' - message_to = m.get('to') - message_from = m.get('from') - message_subject = m.get('subject') - try: - sqlcur.execute(''' - INSERT INTO uids (uid, message_num) - SELECT ?, message_num FROM messages WHERE - rfc822_msgid = ? AND - message_internaldate = ? - GROUP BY rfc822_msgid - HAVING count(*) = 1''', - (uid, - msgid, - message_internaldate)) - except Exception, e: - print e - print e.message - print uid, msgid - if sqlcur.lastrowid is None: - print uid, rfc822_msgid - print "\b.", - sys.stdout.flush() - # There is no need to maintain the Index for normal operations - sqlcur.execute('DROP INDEX msgidx') - sqlconn.commit() - -def doesTokenMatchEmail(cli_email, debug=False): - cfgFile = '%s%s.cfg' % (getProgPath(), cli_email) - storage = oauth2client.file.Storage(cfgFile) - credentials = storage.get() - disable_ssl_certificate_validation = False - if os.path.isfile(getProgPath()+'noverifyssl.txt'): - disable_ssl_certificate_validation = True - http = httplib2.Http(ca_certs=getProgPath()+'cacert.pem', disable_ssl_certificate_validation=disable_ssl_certificate_validation) - if debug: - httplib2.debuglevel = 4 - if credentials.access_token_expired: - credentials.refresh(http) - oa2 = apiclient.discovery.build('oauth2', 'v2', http=http) - token_info = callGAPI(service=oa2, function='tokeninfo', access_token=credentials.access_token) - if token_info['email'].lower() == cli_email.lower(): - return True - return False - -def restart_line(): - sys.stdout.write('\r') - sys.stdout.flush() - -def initializeDB(sqlcur, 
sqlconn, email, uidvalidity): - sqlcur.executescript(''' - CREATE TABLE messages(message_num INTEGER PRIMARY KEY, - message_filename TEXT, - message_to TEXT, - message_from TEXT, - message_subject TEXT, - message_internaldate TIMESTAMP, - rfc822_msgid TEXT); - CREATE TABLE labels (message_num INTEGER, label TEXT); - CREATE TABLE flags (message_num INTEGER, flag TEXT); - CREATE TABLE uids (message_num INTEGER, uid INTEGER PRIMARY KEY); - CREATE TABLE settings (name TEXT PRIMARY KEY, value TEXT); - CREATE UNIQUE INDEX labelidx ON labels (message_num, label); - CREATE UNIQUE INDEX flagidx ON flags (message_num, flag); - ''') - sqlcur.executemany('INSERT INTO settings (name, value) VALUES (?, ?)', - (('email_address', email), - ('db_version', __db_schema_version__), - ('uidvalidity', uidvalidity))) - sqlconn.commit() - -def get_message_size(imapconn, uids): - if type(uids) == type(int()): - uid_string == str(uid) - else: - uid_string = ','.join(uids) - t, d = imapconn.uid('FETCH', uid_string, '(RFC822.SIZE)') - if t != 'OK': - print "Failed to retrieve size for message %s" % uid - print "%s %s" % (t, d) - exit(9) - total_size = 0 - for x in d: - try: - message_size = int(re.search('^[0-9]* \(UID [0-9]* RFC822.SIZE ([0-9]*)\)$', x).group(1)) - total_size = total_size + message_size - except AttributeError: - pass - return total_size - -def getGYBVersion(divider="\n"): - return ('Got Your Back %s~DIV~%s - %s~DIV~Python %s.%s.%s %s-bit %s~DIV~%s %s' % (__version__, __author__, __email__, - sys.version_info[0], sys.version_info[1], sys.version_info[2], struct.calcsize('P')*8, - sys.version_info[3], platform.platform(), platform.machine())).replace('~DIV~', divider) - -def main(argv): - options_parser = SetupOptionParser() - (options, args) = options_parser.parse_args(args=argv) - if options.version: - print getGYBVersion() - sys.exit(0) - if not options.email: - options_parser.print_help() - print "\nERROR: --email is required." 
- return - if options.local_folder == 'XXXuse-email-addressXXX': - options.local_folder = "GYB-GMail-Backup-%s" % options.email - if not options.service_account: # 3-Legged OAuth - if options.use_admin: - auth_as = options.use_admin - else: - auth_as = options.email - requestOAuthAccess(auth_as, options.debug) - if not doesTokenMatchEmail(auth_as, options.debug): - print "Error: you did not authorize the OAuth token in the browser with the %s Google Account. Please make sure you are logged in to the correct account when authorizing the token in the browser." % auth_as - cfgFile = '%s%s.cfg' % (getProgPath(), auth_as) - os.remove(cfgFile) - sys.exit(9) - - if not os.path.isdir(options.local_folder): - if options.action in ['backup',]: - os.mkdir(options.local_folder) - elif options.action in ['restore', 'restore-group']: - print 'Error: Folder %s does not exist. Cannot restore.' % options.local_folder - sys.exit(3) - - if options.action not in ['restore-group']: - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - global ALL_MAIL, TRASH, SPAM - label_mappings = gimaplib.GImapGetFolders(imapconn) - try: - ALL_MAIL = label_mappings[u'\\All'] - except KeyError: - print 'Error: Cannot find the Gmail "All Mail" folder. Please make sure it is not hidden from IMAP.' - sys.exit(3) - if not options.use_folder: - options.use_folder = ALL_MAIL - r, d = imapconn.select(ALL_MAIL, readonly=True) - if r == 'NO': - print "Error: Cannot select the Gmail \"All Mail\" folder. Please make sure it is not hidden from IMAP." - sys.exit(3) - uidvalidity = imapconn.response('UIDVALIDITY')[1][0] - - sqldbfile = os.path.join(options.local_folder, 'msg-db.sqlite') - # Do we need to initialize a new database? 
- newDB = (not os.path.isfile(sqldbfile)) and (options.action in ['backup', u'restore-mbox']) - - #If we're not doing a estimate or if the db file actually exists we open it (creates db if it doesn't exist) - if options.action not in ['count', 'purge', 'purge-labels', 'quota']: - if options.action not in ['estimate'] or os.path.isfile(sqldbfile): - print "\nUsing backup folder %s" % options.local_folder - global sqlconn - global sqlcur - sqlconn = sqlite3.connect(sqldbfile, detect_types=sqlite3.PARSE_DECLTYPES) - sqlconn.text_factory = str - sqlcur = sqlconn.cursor() - if newDB: - initializeDB(sqlcur, sqlconn, options.email, uidvalidity) - db_settings = get_db_settings(sqlcur) - check_db_settings(db_settings, options.action, options.email) - if options.action not in ['restore', 'restore-group', u'restore-mbox']: - if ('uidvalidity' not in db_settings or - db_settings['db_version'] < __db_schema_version__): - convertDB(sqlconn, uidvalidity, db_settings['db_version']) - db_settings = get_db_settings(sqlcur) - if options.action == 'reindex': - getMessageIDs(sqlconn, options.local_folder) - rebuildUIDTable(imapconn, sqlconn) - sqlconn.execute(''' - UPDATE settings SET value = ? where name = 'uidvalidity' - ''', ((uidvalidity),)) - sqlconn.commit() - sys.exit(0) - - if db_settings['uidvalidity'] != uidvalidity: - print "Because of changes on the Gmail server, this folder cannot be used for incremental backups." - sys.exit(3) - - # BACKUP # - if options.action == 'backup': - print 'Using folder %s' % options.use_folder - imapconn.select(options.use_folder, readonly=True) - messages_to_process = gimaplib.GImapSearch(imapconn, options.gmail_search) - backup_path = options.local_folder - if not os.path.isdir(backup_path): - os.mkdir(backup_path) - messages_to_backup = [] - messages_to_refresh = [] - #Determine which messages from the search we haven't processed before. 
- print "GYB needs to examine %s messages" % len(messages_to_process) - for message_num in messages_to_process: - if not newDB and message_is_backed_up(message_num, sqlcur, sqlconn, options.local_folder): - messages_to_refresh.append(message_num) - else: - messages_to_backup.append(message_num) - print "GYB already has a backup of %s messages" % (len(messages_to_process) - len(messages_to_backup)) - backup_count = len(messages_to_backup) - print "GYB needs to backup %s messages" % backup_count - messages_at_once = options.batch_size - backed_up_messages = 0 - header_parser = email.parser.HeaderParser() - for working_messages in batch(messages_to_backup, messages_at_once): - working_messages=list(working_messages) - #Save message content - batch_string = ','.join(working_messages) - bad_count = 0 - while True: - try: - r, d = imapconn.uid('FETCH', batch_string, '(X-GM-LABELS INTERNALDATE FLAGS BODY.PEEK[])') - if r != 'OK': - bad_count = bad_count + 1 - if bad_count > 7: - print "\nError: failed to retrieve messages." - print "%s %s" % (r, d) - sys.exit(5) - sleep_time = math.pow(2, bad_count) - sys.stdout.write("\nServer responded with %s %s, will retry in %s seconds" % (r, d, str(sleep_time))) - time.sleep(sleep_time) # sleep 2 seconds, then 4, 8, 16, 32, 64, 128 - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL, readonly=True) - continue - break - except imaplib.IMAP4.abort, e: - print 'imaplib.abort error:%s, retrying...' % e - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL, readonly=True) - except socket.error, e: - print 'socket.error:%s, retrying...' 
% e - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL, readonly=True) - requested_count = len(working_messages) * 2 - d = d[:requested_count - 1] # cut off the extraneous responses for modified messages in the mailbox that we didn't request - for everything_else_string, full_message in (x for x in d if x != ')'): - search_results = re.search('X-GM-LABELS \((.*)\) UID ([0-9]*) (INTERNALDATE \".*\") (FLAGS \(.*\))', everything_else_string) - labels_str = search_results.group(1) - quoted_labels = shlex.split(labels_str, posix=False) - labels = [] - for label in quoted_labels: - if label[0] == '"' and label[-1] == '"': - label = label[1:-1] - if label[:2] == '\\\\': - label = label[1:] - labels.append(label) - uid = search_results.group(2) - message_date_string = search_results.group(3) - message_flags_string = search_results.group(4) - try: - message_date = imaplib.Internaldate2tuple(message_date_string) - except OverflowError: # Bad internal time? Use now... 
- message_date = time.gmtime() - time_seconds_since_epoch = time.mktime(message_date) - message_internal_datetime = datetime.datetime.fromtimestamp(time_seconds_since_epoch) - message_flags = imaplib.ParseFlags(message_flags_string) - message_file_name = "%s-%s.eml" % (uidvalidity, uid) - message_rel_path = os.path.join(str(message_date.tm_year), - str(message_date.tm_mon), - str(message_date.tm_mday)) - message_rel_filename = os.path.join(message_rel_path, - message_file_name) - message_full_path = os.path.join(options.local_folder, - message_rel_path) - message_full_filename = os.path.join(options.local_folder, - message_rel_filename) - if not os.path.isdir(message_full_path): - os.makedirs(message_full_path) - f = open(message_full_filename, 'wb') - f.write(full_message) - f.close() - m = header_parser.parsestr(full_message, True) - message_from = m.get('from') - message_to = m.get('to') - message_subj = m.get('subject') - message_id = m.get('message-id') - sqlcur.execute(""" - INSERT INTO messages ( - message_filename, - message_to, - message_from, - message_subject, - message_internaldate, - rfc822_msgid) VALUES (?, ?, ?, ?, ?, ?)""", - (message_rel_filename, - message_to, - message_from, - message_subj, - message_internal_datetime, - message_id)) - message_num = sqlcur.lastrowid - sqlcur.execute(""" - REPLACE INTO uids (message_num, uid) VALUES (?, ?)""", - (message_num, uid)) - for label in labels: - sqlcur.execute(""" - INSERT INTO labels (message_num, label) VALUES (?, ?)""", - (message_num, label)) - for flag in message_flags: - sqlcur.execute(""" - INSERT INTO flags (message_num, flag) VALUES (?, ?)""", - (message_num, flag)) - backed_up_messages += 1 - - sqlconn.commit() - restart_line() - sys.stdout.write("backed up %s of %s messages" % (backed_up_messages, backup_count)) - sys.stdout.flush() - print "\n" - - if not options.refresh: - messages_to_refresh = [] - backed_up_messages = 0 - backup_count = len(messages_to_refresh) - print "GYB needs to 
refresh %s messages" % backup_count - sqlcur.executescript(""" - CREATE TEMP TABLE current_labels (label TEXT); - CREATE TEMP TABLE current_flags (flag TEXT); - """) - messages_at_once *= 100 - for working_messages in batch(messages_to_refresh, messages_at_once): - #Save message content - batch_string = ','.join(working_messages) - bad_count = 0 - while True: - try: - r, d = imapconn.uid('FETCH', batch_string, '(X-GM-LABELS FLAGS)') - if r != 'OK': - bad_count = bad_count + 1 - if bad_count > 7: - print "\nError: failed to retrieve messages." - print "%s %s" % (r, d) - sys.exit(5) - sleep_time = math.pow(2, bad_count) - sys.stdout.write("\nServer responded with %s %s, will retry in %s seconds" % (r, d, str(sleep_time))) - time.sleep(sleep_time) # sleep 2 seconds, then 4, 8, 16, 32, 64, 128 - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL, readonly=True) - continue - break - except imaplib.IMAP4.abort, e: - print 'imaplib.abort error:%s, retrying...' % e - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL, readonly=True) - except socket.error, e: - print 'socket.error:%s, retrying...' 
% e - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL, readonly=True) - for results in d: - search_results = re.search('X-GM-LABELS \((.*)\) UID ([0-9]*) (FLAGS \(.*\))', results) - labels = shlex.split(search_results.group(1), posix=False) - uid = search_results.group(2) - message_flags_string = search_results.group(3) - message_flags = imaplib.ParseFlags(message_flags_string) - sqlcur.execute('DELETE FROM current_labels') - sqlcur.execute('DELETE FROM current_flags') - sqlcur.executemany( - 'INSERT INTO current_labels (label) VALUES (?)', - ((label,) for label in labels)) - sqlcur.executemany( - 'INSERT INTO current_flags (flag) VALUES (?)', - ((flag,) for flag in message_flags)) - sqlcur.execute("""DELETE FROM labels where message_num = - (SELECT message_num from uids where uid = ?) - AND label NOT IN current_labels""", ((uid),)) - sqlcur.execute("""DELETE FROM flags where message_num = - (SELECT message_num from uids where uid = ?) - AND flag NOT IN current_flags""", ((uid),)) - sqlcur.execute("""INSERT INTO labels (message_num, label) - SELECT message_num, label from uids, current_labels - WHERE uid = ? AND label NOT IN - (SELECT label FROM labels - WHERE message_num = uids.message_num)""", ((uid),)) - sqlcur.execute("""INSERT INTO flags (message_num, flag) - SELECT message_num, flag from uids, current_flags - WHERE uid = ? 
AND flag NOT IN - (SELECT flag FROM flags - WHERE message_num = uids.message_num)""", ((uid),)) - backed_up_messages += 1 - - sqlconn.commit() - restart_line() - sys.stdout.write("refreshed %s of %s messages" % (backed_up_messages, backup_count)) - sys.stdout.flush() - print "\n" - - # RESTORE # - elif options.action == 'restore': - print 'using IMAP folder %s' % options.use_folder - imapconn.select(options.use_folder) - resumedb = os.path.join(options.local_folder, - "%s-restored.sqlite" % options.email) - if options.noresume: - try: - os.remove(resumedb) - except OSError: - pass - except IOError: - pass - sqlcur.execute('ATTACH ? as resume', (resumedb,)) - sqlcur.executescript('''CREATE TABLE IF NOT EXISTS resume.restored_messages - (message_num INTEGER PRIMARY KEY); - CREATE TEMP TABLE skip_messages (message_num INTEGER PRIMARY KEY);''') - sqlcur.execute('''INSERT INTO skip_messages SELECT message_num from restored_messages''') - sqlcur.execute('''SELECT message_num, message_internaldate, message_filename FROM messages - WHERE message_num NOT IN skip_messages ORDER BY message_internaldate DESC''') # All messages - - messages_to_restore_results = sqlcur.fetchall() - restore_count = len(messages_to_restore_results) - current = 0 - created_labels = [] - for x in messages_to_restore_results: - restart_line() - current += 1 - message_filename = x[2] - sys.stdout.write("restoring message %s of %s from %s" % (current, restore_count, message_filename)) - sys.stdout.flush() - message_num = x[0] - message_internaldate = x[1] - message_internaldate_seconds = time.mktime(message_internaldate.timetuple()) - if not os.path.isfile(os.path.join(options.local_folder, message_filename)): - print 'WARNING! file %s does not exist for message %s' % (os.path.join(options.local_folder, message_filename), message_num) - print ' this message will be skipped.' 
- continue - f = open(os.path.join(options.local_folder, message_filename), 'rb') - full_message = f.read() - f.close() - full_message = full_message.replace('\x00', '') # No NULL chars - labels = [] - if not options.strip_labels: - labels_query = sqlcur.execute('SELECT DISTINCT label FROM labels WHERE message_num = ?', (message_num,)) - labels_results = sqlcur.fetchall() - for l in labels_results: - labels.append(l[0].replace('\\','\\\\').replace('"','\\"')) - if options.label_restored: - labels.append(options.label_restored) - for label in labels: - if label not in created_labels and label.find('/') != -1: # create parent labels - create_label = label - while True: - imapconn.create(create_label) - created_labels.append(create_label) - if create_label.find('/') == -1: - break - create_label = create_label[:create_label.rfind('/')] - flags_query = sqlcur.execute('SELECT DISTINCT flag FROM flags WHERE message_num = ?', (message_num,)) - flags_results = sqlcur.fetchall() - flags = [] - for f in flags_results: - flags.append(f[0]) - flags_string = ' '.join(flags) - while True: - try: - r, d = imapconn.append(options.use_folder, flags_string, message_internaldate_seconds, full_message) - if r != 'OK': - print '\nError: %s %s' % (r,d) - sys.exit(5) - try: - restored_uid = int(re.search('^[APPENDUID [0-9]* ([0-9]*)] \(Success\)$', d[0]).group(1)) - except AttributeError: - print '\nerror retrieving uid: %s: retrying...' % d - time.sleep(3) - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL) - if len(labels) > 0: - labels_string = '("'+'" "'.join(labels)+'")' - r, d = imapconn.uid('STORE', restored_uid, '+X-GM-LABELS', labels_string) - if r != 'OK': - print '\nGImap Set Message Labels Failed: %s %s' % (r, d) - sys.exit(33) - break - except imaplib.IMAP4.abort, e: - print '\nimaplib.abort error:%s, retrying...' 
% e - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL) - except socket.error, e: - print '\nsocket.error:%s, retrying...' % e - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL) - #Save the fact that it is completed - sqlconn.execute( - 'INSERT OR IGNORE INTO restored_messages (message_num) VALUES (?)', - (message_num,)) - sqlconn.commit() - sqlconn.execute('DETACH resume') - sqlconn.commit() - - # RESTORE-MBOX # - elif options.action == 'restore-mbox': - imapconn.select(options.use_folder) - resumedb = os.path.join(options.local_folder, - "%s-restored.sqlite" % options.email) - if options.noresume: - try: - os.remove(resumedb) - except OSError: - pass - except IOError: - pass - sqlcur.execute('ATTACH ? as mbox_resume', (resumedb,)) - sqlcur.executescript('''CREATE TABLE IF NOT EXISTS mbox_resume.restored_messages - (message TEXT PRIMARY KEY)''') - sqlcur.execute('''SELECT message FROM mbox_resume.restored_messages''') - messages_to_skip_results = sqlcur.fetchall() - messages_to_skip = [] - for a_message in messages_to_skip_results: - messages_to_skip.append(a_message[0]) - if os.name == 'windows' or os.name == 'nt': - divider = '\\' - else: - divider = '/' - created_labels = [] - for path, subdirs, files in os.walk(options.local_folder): - for filename in files: - if filename[-4:].lower() != u'.mbx' and filename[-5:].lower() != u'.mbox': - continue - file_path = '%s%s%s' % (path, divider, filename) - mbox = mailbox.mbox(file_path) - mbox_count = len(mbox.items()) - current = 0 - print "\nRestoring from %s" % file_path - for message in mbox: - current += 1 - message_marker = '%s-%s' % (file_path, current) - if message_marker in messages_to_skip: - continue - restart_line() - labels = message[u'X-Gmail-Labels'] - if labels != None and labels != u'' and not options.strip_labels: - bytes, 
encoding = email.header.decode_header(labels)[0] - if encoding != None: - try: - labels = bytes.decode(encoding) - except UnicodeDecodeError: - pass - else: - labels = labels.decode('string-escape') - labels = labels.split(u',') - else: - labels = [] - if options.label_restored: - labels.append(options.label_restored) - for label in labels: - if label not in created_labels and label.find('/') != -1: # create parent labels - create_label = label - while True: - imapconn.create(create_label) - created_labels.append(create_label) - if create_label.find('/') == -1: - break - create_label = create_label[:create_label.rfind('/')] - flags = [] - if u'Unread' in labels: - labels.remove(u'Unread') - else: - flags.append(u'\Seen') - if u'Starred' in labels: - flags.append(u'\Flagged') - labels.remove(u'Starred') - for bad_label in [u'Sent', u'Inbox', u'Important', u'Drafts', u'Chat', u'Muted', u'Trash', u'Spam']: - if bad_label in labels: - labels.remove(bad_label) - if bad_label == u'Chat': - labels.append(u'Restored Chats') - elif bad_label == u'Drafts': - labels.append(u'\\\\Draft') - else: - labels.append(u'\\\\%s' % bad_label) - escaped_labels = [] - for label in labels: - if label.find('\"') != -1: - escaped_labels.append(label.replace('\"', '\\"')) - else: - escaped_labels.append(label) - del message[u'X-Gmail-Labels'] - del message[u'X-GM-THRID'] - flags_string = ' '.join(flags) - msg_account, internal_datetime = message.get_from().split(' ', 1) - internal_datetime_seconds = time.mktime(email.utils.parsedate(internal_datetime)) - sys.stdout.write(" message %s of %s" % (current, mbox_count)) - sys.stdout.flush() - full_message = message.as_string() - while True: - try: - r, d = imapconn.append(options.use_folder, flags_string, internal_datetime_seconds, full_message) - if r != 'OK': - print '\nError: %s %s' % (r,d) - sys.exit(5) - restored_uid = int(re.search('^[APPENDUID [0-9]* ([0-9]*)] \(Success\)$', d[0]).group(1)) - if len(labels) > 0: - labels_string = '("'+'" 
"'.join(escaped_labels)+'")' - r, d = imapconn.uid('STORE', restored_uid, '+X-GM-LABELS', labels_string) - if r != 'OK': - print '\nGImap Set Message Labels Failed: %s %s' % (r, d) - sys.exit(33) - break - except imaplib.IMAP4.abort, e: - print '\nimaplib.abort error:%s, retrying...' % e - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL) - except socket.error, e: - print '\nsocket.error:%s, retrying...' % e - imapconn = gimaplib.ImapConnect(generateXOAuthString(options.email, options.service_account), options.debug) - imapconn.select(ALL_MAIL) - #Save the fact that it is completed - sqlconn.execute( - 'INSERT INTO restored_messages (message) VALUES (?)', - (message_marker,)) - sqlconn.commit() - sqlconn.execute('DETACH mbox_resume') - sqlconn.commit() - - # RESTORE-GROUP # - elif options.action == 'restore-group': - resumedb = os.path.join(options.local_folder, - "%s-restored.sqlite" % options.email) - if options.noresume: - try: - os.remove(resumedb) - except OSError: - pass - except IOError: - pass - sqlcur.execute('ATTACH ? as resume', (resumedb,)) - sqlcur.executescript('''CREATE TABLE IF NOT EXISTS resume.restored_messages - (message_num INTEGER PRIMARY KEY); - CREATE TEMP TABLE skip_messages (message_num INTEGER PRIMARY KEY);''') - sqlcur.execute('''INSERT INTO skip_messages SELECT message_num from restored_messages''') - sqlcur.execute('''SELECT message_num, message_internaldate, message_filename FROM messages - WHERE message_num NOT IN skip_messages ORDER BY message_internaldate DESC''') # All messages - messages_to_restore_results = sqlcur.fetchall() - restore_count = len(messages_to_restore_results) - if options.service_account: - if not options.use_admin: - print 'Error: --restore_group and --service_account require --user_admin to specify Google Apps Admin to utilize.' 
- sys.exit(5) - try: - f = file(getProgPath()+'privatekey.p12', 'rb') - key = f.read() - f.close() - service_account_name = service_account - except IOError: - json_string = file(getProgPath()+'privatekey.json', 'rb').read() - json_data = simplejson.loads(json_string) - key = json_data['private_key'] - service_account_name = json_data['client_email'] - scope = 'https://www.googleapis.com/auth/apps.groups.migration' - credentials = oauth2client.client.SignedJwtAssertionCredentials(options.service_account_name, key, scope=scope, prn=options.use_admin) - disable_ssl_certificate_validation = False - if os.path.isfile(getProgPath()+'noverifyssl.txt'): - disable_ssl_certificate_validation = True - http = httplib2.Http(ca_certs=getProgPath()+'cacert.pem', disable_ssl_certificate_validation=disable_ssl_certificate_validation) - if options.debug: - httplib2.debuglevel = 4 - http = credentials.authorize(http) - elif options.use_admin: - cfgFile = '%s%s.cfg' % (getProgPath(), options.use_admin) - f = open(cfgFile, 'rb') - token = simplejson.load(f) - f.close() - storage = oauth2client.file.Storage(cfgFile) - credentials = storage.get() - disable_ssl_certificate_validation = False - if os.path.isfile(getProgPath()+'noverifyssl.txt'): - disable_ssl_certificate_validation = True - http = httplib2.Http(ca_certs=getProgPath()+'cacert.pem', disable_ssl_certificate_validation=disable_ssl_certificate_validation) - if options.debug: - httplib2.debuglevel = 4 - http = credentials.authorize(http) - else: - print 'Error: restore-group requires that --use_admin is also specified.' 
- sys.exit(5) - gmig = apiclient.discovery.build('groupsmigration', 'v1', http=http) - current = 0 - for x in messages_to_restore_results: - restart_line() - current += 1 - sys.stdout.write("restoring message %s of %s from %s" % (current, restore_count, x[1])) - sys.stdout.flush() - message_num = x[0] - message_internaldate = x[1] - message_filename = x[2] - if not os.path.isfile(os.path.join(options.local_folder, message_filename)): - print 'WARNING! file %s does not exist for message %s' % (os.path.join(options.local_folder, message_filename), message_num) - print ' this message will be skipped.' - continue - f = open(os.path.join(options.local_folder, message_filename), 'rb') - full_message = f.read() - f.close() - media = apiclient.http.MediaFileUpload(os.path.join(options.local_folder, message_filename), mimetype='message/rfc822') - callGAPI(service=gmig.archive(), function='insert', groupId=options.email, media_body=media) - #Save the fact that it is completed - sqlconn.execute( -# 'INSERT OR IGNORE INTO restored_messages (message_num) VALUES (?)', - 'INSERT INTO restored_messages (message_num) VALUES (?)', - (message_num,)) - sqlconn.commit() - sqlconn.execute('DETACH resume') - sqlconn.commit() - - # COUNT - elif options.action == 'count': - print 'Using label %s' % options.use_folder - imapconn.select(options.use_folder, readonly=True) - messages_to_process = gimaplib.GImapSearch(imapconn, options.gmail_search) - messages_to_estimate = [] - #if we have a sqlcur , we'll compare messages to the db - #otherwise just estimate everything - for message_num in messages_to_process: - try: - sqlcur - if message_is_backed_up(message_num, sqlcur, sqlconn, options.local_folder): - continue - else: - messages_to_estimate.append(message_num) - except NameError: - messages_to_estimate.append(message_num) - estimate_count = len(messages_to_estimate) - total_size = float(0) - list_position = 0 - messages_at_once = 10000 - loop_count = 0 - print "%s,%s" % (options.email, 
estimate_count) - - # PURGE # - elif options.action == 'purge': - print 'Using label %s' % options.use_folder - imapconn.select(options.use_folder, readonly=False) - messages_to_process = gimaplib.GImapSearch(imapconn, options.gmail_search) - print 'Moving %s messages from All Mail to Trash for %s' % (len(messages_to_process), options.email) - messages_at_once = 1000 - loop_count = 0 - for working_messages in batch(messages_to_process, messages_at_once): - uid_string = ','.join(working_messages) - t, d = imapconn.uid('STORE', uid_string, '+X-GM-LABELS', '\\Trash') - try: - SPAM = label_mappings[u'\\Junk'] - except KeyError: - print 'Error: could not select the Spam folder. Please make sure it is not hidden from IMAP.' - sys.exit(2) - r, d = imapconn.select(SPAM, readonly=False) - if r == 'NO': - print "Error: Cannot select the Gmail \"Spam\" folder. Please make sure it is not hidden from IMAP." - sys.exit(3) - spam_uids = gimaplib.GImapSearch(imapconn, options.gmail_search) - print 'Purging %s Spam messages for %s' % (len(spam_uids), options.email) - for working_messages in batch(spam_uids, messages_at_once): - spam_uid_string = ','.join(working_messages) - t, d = imapconn.uid('STORE', spam_uid_string, '+FLAGS', '\Deleted') - imapconn.expunge() - try: - TRASH = label_mappings[u'\\Trash'] - except KeyError: - print 'Error: could not select the Trash folder. Please make sure it is not hidden from IMAP.' - sys.exit(4) - r, d = imapconn.select(TRASH, readonly=False) - if r == 'NO': - print "Error: Cannot select the Gmail \"Trash\" folder. Please make sure it is not hidden from IMAP." 
- sys.exit(3) - trash_uids = gimaplib.GImapSearch(imapconn, options.gmail_search) - print 'Purging %s Trash messages for %s' % (len(trash_uids), options.email) - for working_messages in batch(trash_uids, messages_at_once): - trash_uid_string = ','.join(working_messages) - t, d = imapconn.uid('STORE', trash_uid_string, '+FLAGS', '\Deleted') - imapconn.expunge() - - # PURGE-LABELS # - elif options.action == u'purge-labels': - pattern = options.gmail_search - if pattern == u'in:anywhere': - pattern = u'*' - pattern = r'%s' % pattern - r, existing_labels = imapconn.list(pattern=pattern) - for label_result in existing_labels: - if type(label_result) is not str: - continue - label = re.search(u'\" \"(.*)\"$', label_result).group(1) - if label == u'INBOX' or label == u'Deleted' or label[:7] == u'[Gmail]': - continue - - # ugly hacking of imaplib to keep it from overquoting/escaping - funcType = type(imapconn._quote) - imapconn._quote = funcType(just_quote, imapconn, imapconn) - - print u'Deleting label %s' % label - try: - r, d = imapconn.delete(label) - except imaplib.IMAP4.error, e: - print 'bad response: %s' % e - - # QUOTA # - elif options.action == 'quota': - result = imapconn.getquotaroot("INBOX")[1][1][0] - quota_used, quota_size = re.search('^".*" \(STORAGE ([0-9]*) ([0-9]*)\)$', result).groups() - quota_used = float(quota_used) / 1024.0 - quota_size = float(quota_size) / 1024.0 - used_pct = (quota_used / quota_size) * 100 - quota_used_term = 'MB' - quota_size_term = 'MB' - if quota_size > 1024.0: - quota_size = quota_size / 1024.0 - quota_size_term = 'GB' - if quota_size > 1024.0: - quota_size = quota_size / 1024.0 - quota_size_term = 'TB' - if quota_size > 1024.0: - quota_size = quota_size / 1024.0 - quota_size_term = 'PB' - if quota_used > 1024.0: - quota_used = quota_used / 1024.0 - quota_used_term = 'GB' - if quota_used > 1024.0: - quota_used = quota_used / 1024.0 - quota_used_term = 'TB' - if quota_used > 1024.0: - quota_used = quota_used / 1024.0 - 
quota_used_term = 'PB' - - print 'Total Google Storage: %.2f %s' % (quota_size, quota_size_term) - print 'Used Google Storage: %.2f %s' % (quota_used, quota_used_term) - print '%.2f%%' % used_pct - - # ESTIMATE # - elif options.action == 'estimate': - imapconn.select(options.use_folder, readonly=True) - messages_to_process = gimaplib.GImapSearch(imapconn, options.gmail_search) - messages_to_estimate = [] - #if we have a sqlcur , we'll compare messages to the db - #otherwise just estimate everything - for message_num in messages_to_process: - try: - sqlcur - if message_is_backed_up(message_num, sqlcur, sqlconn, options.local_folder): - continue - else: - messages_to_estimate.append(message_num) - except NameError: - messages_to_estimate.append(message_num) - estimate_count = len(messages_to_estimate) - total_size = float(0) - list_position = 0 - messages_at_once = 10000 - loop_count = 0 - print 'Email: %s' % options.email - print "Messages to estimate: %s" % estimate_count - estimated_messages = 0 - for working_messages in batch(messages_to_estimate, messages_at_once): - messages_size = get_message_size(imapconn, working_messages) - total_size = total_size + messages_size - if total_size > (1024 * 1024 * 1024): - math_size = total_size/(1024 * 1024 * 1024) - print_size = "%.2f GB" % math_size - elif total_size > (1024 * 1024): - math_size = total_size / (1024 * 1024) - print_size = "%.2f MB" % math_size - elif total_size > 1024: - math_size = total_size/1024 - print_size = "%.2f KB" % math_size - else: - print_size = "%.2f bytes" % total_size - if estimated_messages+messages_at_once < estimate_count: - estimated_messages = estimated_messages + messages_at_once - else: - estimated_messages = estimate_count - restart_line() - sys.stdout.write("Messages estimated: %s Estimated size: %s " % (estimated_messages, print_size)) - sys.stdout.flush() - time.sleep(1) - print "" - try: - sqlconn.close() - except NameError: - pass - try: - imapconn.logout() - except 
UnboundLocalError: # group-restore never does imapconn - pass - -if __name__ == '__main__': - reload(sys) - sys.setdefaultencoding(u'UTF-8') - if os.name == u'nt': - sys.argv = win32_unicode_argv() # cleanup sys.argv on Windows - doGYBCheckForUpdates() - try: - main(sys.argv[1:]) - except KeyboardInterrupt: - try: - sqlconn.commit() - sqlconn.close() - print - except NameError: - pass - sys.exit(4) +#!/usr/bin/env python3 +# +# Got Your Back +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""\n%s\n\nGot Your Back (GYB) is a command line tool which allows users to +backup and restore their Gmail. 
+ +For more information, see http://git.io/gyb/ +""" + +global __name__, __author__, __email__, __version__, __license__ +__program_name__ = 'Got Your Back: Gmail Backup' +__author__ = 'Jay Lee' +__email__ = 'jay0lee@gmail.com' +__version__ = '0.40' +__license__ = 'Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)' +__website__ = 'http://git.io/gyb' +__db_schema_version__ = '6' +__db_schema_min_version__ = '6' #Minimum for restore + +global extra_args, options, allLabelIds, allLabels, gmail +extra_args = {'prettyPrint': False} +allLabelIds = dict() +allLabels = dict() + +import argparse +import sys +import os +import os.path +import time +import random +import struct +import platform +import datetime +import sqlite3 +import email +import mailbox +import re +from itertools import islice, chain +import base64 +import json + +try: + import json as simplejson +except ImportError: + import simplejson + +import httplib2 +import oauth2client.client +import oauth2client.file +import oauth2client.tools +import googleapiclient +import googleapiclient.discovery +import googleapiclient.errors + +def SetupOptionParser(argv): + parser = argparse.ArgumentParser(add_help=False) + #parser.usage = parser.print_help() + parser.add_argument('--email', + dest='email', + help='Full email address of user or group to act against') + action_choices = ['backup','restore', 'restore-group', 'restore-mbox', + 'count', 'purge', 'purge-labels', 'estimate', 'quota', 'reindex', 'revoke'] + parser.add_argument('--action', + choices=action_choices, + dest='action', + default='backup', + help='Action to perform. Default is backup.') + parser.add_argument('--search', + dest='gmail_search', + default=None, + help='Optional: On backup, estimate, count and purge, Gmail search to \ +scope operation against') + parser.add_argument('--local-folder', + dest='local_folder', + help='Optional: On backup, restore, estimate, local folder to use. 
\ +Default is GYB-GMail-Backup-', + default='XXXuse-email-addressXXX') + parser.add_argument('--label-restored', + action='append', + dest='label_restored', + help='Optional: On restore, all messages will additionally receive \ +this label. For example, "--label_restored gyb-restored" will label all \ +uploaded messages with a gyb-restored label.') + parser.add_argument('--strip-labels', + dest='strip_labels', + action='store_true', + default=False, + help='Optional: On restore and restore-mbox, strip existing labels from \ +messages except for those explicitly declared with the --label-restored \ +parameter.') + parser.add_argument('--vault', + action='store_true', + default=None, + dest='vault', + help='Optional: On restore and restore-mbox, restored messages will not be\ +visible in user\'s Gmail but are subject to Vault discovery/retention.') + parser.add_argument('--service-account', + action='store_true', + dest='service_account', + help='Google Apps Business and Education only. Use OAuth 2.0 Service \ +Account to authenticate.') + parser.add_argument('--use-admin', + dest='use_admin', + help='Optional: On restore-group, authenticate as this admin user.') + parser.add_argument('--spam-trash', + dest='spamtrash', + action='store_true', + help='Optional: Include Spam and Trash folders in backup, estimate and count actions. This is always enabled for purge.') + parser.add_argument('--batch-size', + dest='batch_size', + metavar='{1 - 100}', + type=int, + choices=list(range(1,101)), + default=0, # default of 0 means use per action default + help='Optional: Sets the number of batch operations to perform at once.') + parser.add_argument('--noresume', + action='store_true', + help='Optional: On restores, start from beginning. Default is to resume \ +where last restore left off.') + parser.add_argument('--fast-restore', + action='store_true', + dest='fast_restore', + help='Optional: On restores, use the fast method. 
WARNING: using this \ +method breaks Gmail deduplication and threading.') + parser.add_argument('--fast-incremental', + dest='refresh', + action='store_false', + default=True, + help='Optional: On backup, skips refreshing labels for existing message') + parser.add_argument('--debug', + action='store_true', + dest='debug', + help='Turn on verbose debugging and connection information \ +(troubleshooting)') + parser.add_argument('--version', + action='store_true', + dest='version', + help='print GYB version and quit') + parser.add_argument('--help', + action='help', + help='Display this message.') + return parser.parse_args(argv) + +def getProgPath(): + if os.path.abspath('/') != -1: + divider = '/' + else: + divider = '\\' + return os.path.dirname(os.path.realpath(sys.argv[0]))+divider + +class cmd_flags(object): + def __init__(self): + self.short_url = True + self.noauth_local_webserver = False + self.logging_level = 'ERROR' + self.auth_host_name = 'localhost' + self.auth_host_port = [8080, 9090] + +def requestOAuthAccess(): + if options.use_admin: + auth_as = options.use_admin + else: + auth_as = options.email + CLIENT_SECRETS = getProgPath()+'client_secrets.json' + MISSING_CLIENT_SECRETS_MESSAGE = """ +WARNING: Please configure OAuth 2.0 + +To make GYB run you will need to populate the client_secrets.json file +found at: + + %s + +with information from the APIs Console https://console.developers.google.com. 
+ +""" % (CLIENT_SECRETS) + cfgFile = '%s%s.cfg' % (getProgPath(), auth_as) + storage = oauth2client.file.Storage(cfgFile) + credentials = storage.get() + flags = cmd_flags() + if os.path.isfile(getProgPath()+'nobrowser.txt'): + flags.noauth_local_webserver = True + if credentials is None or credentials.invalid: + certFile = getProgPath()+'cacert.pem' + disable_ssl_certificate_validation = False + if os.path.isfile(getProgPath()+'noverifyssl.txt'): + disable_ssl_certificate_validation = True + http = httplib2.Http(ca_certs=certFile, + disable_ssl_certificate_validation=disable_ssl_certificate_validation) + possible_scopes = ['https://www.googleapis.com/auth/gmail.modify', + # Gmail modify + + 'https://www.googleapis.com/auth/gmail.readonly', + # Gmail readonly + + 'https://www.googleapis.com/auth/gmail.insert \ +https://www.googleapis.com/auth/gmail.labels', + # insert and labels + + 'https://mail.google.com/', + # Gmail Full Access + + '', + # No Gmail + + 'https://www.googleapis.com/auth/apps.groups.migration', + # Groups Archive Restore + + 'https://www.googleapis.com/auth/drive.appdata'] + # Drive app config (used for quota) + + selected_scopes = ['*', ' ', ' ', ' ', ' ', ' ', '*'] + menu = '''Select the actions you wish GYB to be able to perform for %s + +[%s] 0) Gmail Backup And Restore - read/write mailbox access +[%s] 1) Gmail Backup Only - read-only mailbox access +[%s] 2) Gmail Restore Only - write-only mailbox access and label management +[%s] 3) Gmail Full Access - read/write mailbox access and message purge +[%s] 4) No Gmail Access + +[%s] 5) Groups Restore - write to Google Apps Groups Archive +[%s] 6) Storage Quota - Drive app config scope used for --action quota + + 7) Continue +''' + os.system(['clear', 'cls'][os.name == 'nt']) + while True: + selection = input(menu % tuple([auth_as]+selected_scopes)) + try: + if int(selection) > -1 and int(selection) <= 6: + if selected_scopes[int(selection)] == ' ': + selected_scopes[int(selection)] = '*' + if 
int(selection) > -1 and int(selection) <= 4: + for i in range(0,5): + if i == int(selection): + continue + selected_scopes[i] = ' ' + else: + selected_scopes[int(selection)] = ' ' + elif selection == '7': + at_least_one = False + for i in range(0, len(selected_scopes)): + if selected_scopes[i] in ['*',]: + if i == 4: + continue + at_least_one = True + if at_least_one: + break + else: + os.system(['clear', 'cls'][os.name == 'nt']) + print("YOU MUST SELECT AT LEAST ONE SCOPE!\n") + continue + else: + os.system(['clear', 'cls'][os.name == 'nt']) + print('NOT A VALID SELECTION!\n') + continue + os.system(['clear', 'cls'][os.name == 'nt']) + except ValueError: + os.system(['clear', 'cls'][os.name == 'nt']) + print('NOT A VALID SELECTION!\n') + continue + scopes = ['email',] + for i in range(0, len(selected_scopes)): + if selected_scopes[i] == '*': + scopes.append(possible_scopes[i]) + FLOW = oauth2client.client.flow_from_clientsecrets(CLIENT_SECRETS, + scope=scopes, message=MISSING_CLIENT_SECRETS_MESSAGE, login_hint=auth_as) + credentials = oauth2client.tools.run_flow(flow=FLOW, storage=storage, + flags=flags, http=http) + certFile = getProgPath()+'cacert.pem' + disable_ssl_certificate_validation = False + if os.path.isfile(getProgPath()+'noverifyssl.txt'): + disable_ssl_certificate_validation = True + http = httplib2.Http(ca_certs=certFile, + disable_ssl_certificate_validation=disable_ssl_certificate_validation) + +def doGYBCheckForUpdates(): + import urllib.request, urllib.error, urllib.parse, calendar + no_update_check_file = getProgPath()+'noupdatecheck.txt' + last_update_check_file = getProgPath()+'lastcheck.txt' + if os.path.isfile(no_update_check_file): return + try: + current_version = float(__version__) + except ValueError: + return + if os.path.isfile(last_update_check_file): + f = open(last_update_check_file, 'rb') + last_check_time = int(f.readline()) + f.close() + else: + last_check_time = 0 + now_time = calendar.timegm(time.gmtime()) + one_week_ago_time = 
now_time - 604800 + if last_check_time > one_week_ago_time: return + try: + checkUrl = 'https://gyb-update.appspot.com/latest-version.txt?v=%s' + c = urllib.request.urlopen(checkUrl % (__version__,)) + try: + latest_version = float(c.read()) + except ValueError: + return + if latest_version <= current_version: + f = open(last_update_check_file, 'w') + f.write(str(now_time)) + f.close() + return + announceUrl = 'https://gyb-update.appspot.com/\ +latest-version-announcement.txt?v=%s' + a = urllib.request.urlopen(announceUrl % (__version__,)) + announcement = a.read() + sys.stderr.write('\nThere\'s a new version of GYB!!!\n\n') + sys.stderr.write(announcement) + visit_gyb = input("\n\nHit Y to visit the GYB website and download \ +the latest release. Hit Enter to just continue with this boring old version.\ + GYB won't bother you with this announcement for 1 week or you can create a \ +file named %s and GYB won't ever check for updates: " % no_update_check_file) + if visit_gyb.lower() == 'y': + import webbrowser + webbrowser.open(__website__) + print('GYB is now exiting so that you can overwrite this old version \ +with the latest release') + sys.exit(0) + f = open(last_update_check_file, 'w') + f.write(str(now_time)) + f.close() + except urllib.error.HTTPError: + return + except urllib.error.URLError: + return + +def getAPIVer(api): + if api == 'oauth2': + return 'v2' + elif api == 'gmail': + return 'v1' + elif api == 'groupsmigration': + return 'v1' + elif api == 'drive': + return 'v2' + return 'v1' + +def getAPIScope(api): + if api == 'gmail': + return ['https://mail.google.com/'] + elif api == 'groupsmigration': + return ['https://www.googleapis.com/auth/apps.groups.migration'] + elif api == 'drive': + return ['https://www.googleapis.com/auth/drive.appdata'] + +def buildGAPIObject(api): + if options.use_admin: + auth_as = options.use_admin + else: + auth_as = options.email + oauth2file = '%s%s.cfg' % (getProgPath(), auth_as) + storage = 
oauth2client.file.Storage(oauth2file) + credentials = storage.get() + if credentials is None or credentials.invalid: + requestOAuthAccess() + credentials = storage.get() + credentials.user_agent = getGYBVersion(' | ') + disable_ssl_certificate_validation = False + if os.path.isfile(getProgPath()+'noverifyssl.txt'): + disable_ssl_certificate_validation = True + http = httplib2.Http(ca_certs=getProgPath()+'cacert.pem', + disable_ssl_certificate_validation=disable_ssl_certificate_validation) + if options.debug: + httplib2.debuglevel = 4 + extra_args['prettyPrint'] = True + if os.path.isfile(getProgPath()+'extra-args.txt'): + import configparser + config = configparser.ConfigParser() + config.optionxform = str + config.read(getProgPath()+'extra-args.txt') + extra_args.update(dict(config.items('extra-args'))) + http = credentials.authorize(http) + version = getAPIVer(api) + try: + return googleapiclient.discovery.build(api, version, http=http) + except googleapiclient.errors.UnknownApiNameOrVersion: + disc_file = getProgPath()+'%s-%s.json' % (api, version) + if os.path.isfile(disc_file): + f = open(disc_file, 'rb') + discovery = f.read() + f.close() + return googleapiclient.discovery.build_from_document(discovery, + base='https://www.googleapis.com', http=http) + else: + print('No online discovery doc and %s does not exist locally' + % disc_file) + raise + +def buildGAPIServiceObject(api, soft_errors=False): + global extra_args + if options.use_admin: + auth_as = options.use_admin + else: + auth_as = options.email + oauth2servicefile = getProgPath()+'oauth2service' + oauth2servicefilejson = '%s.json' % oauth2servicefile + try: + json_string = open(oauth2servicefilejson, 'rb').read() + except IOError as e: + print('Error: %s' % e) + print('') + print('Please follow the instructions at:\n\nhttps://github.com/jay0lee/go\ +t-your-back/wiki#google-apps-business-and-education-admins-backup-restore-and-\ +estimate-users-and-restore-to-groups\n\nto setup a Service Account') + 
sys.exit(6) + json_data = json.loads(json_string) + SERVICE_ACCOUNT_EMAIL = json_data['client_email'] + SERVICE_ACCOUNT_CLIENT_ID = json_data['client_id'] + key = json_data['private_key'].encode('utf-8') + scope = getAPIScope(api) + credentials = oauth2client.client.SignedJwtAssertionCredentials( + SERVICE_ACCOUNT_EMAIL, key, scope=scope, sub=auth_as) + credentials.user_agent = getGYBVersion(' | ') + disable_ssl_certificate_validation = False + if os.path.isfile(getProgPath()+'noverifyssl.txt'): + disable_ssl_certificate_validation = True + http = httplib2.Http(ca_certs=getProgPath()+'cacert.pem', + disable_ssl_certificate_validation=disable_ssl_certificate_validation) + if options.debug: + httplib2.debuglevel = 4 + extra_args['prettyPrint'] = True + if os.path.isfile(getProgPath()+'extra-args.txt'): + import configparser + config = configparser.ConfigParser() + config.optionxform = str + config.read(getProgPath()+'extra-args.txt') + extra_args.update(dict(config.items('extra-args'))) + http = credentials.authorize(http) + version = getAPIVer(api) + try: + return googleapiclient.discovery.build(api, version, http=http) + except oauth2client.client.AccessTokenRefreshError as e: + if e.message in ['access_denied', 'unauthorized_client: Unauthorized \ +client or scope in request.']: + print('Error: Access Denied. 
Please make sure the Client Name:\ +\n\n%s\n\nis authorized for the API Scope(s):\n\n%s\n\nThis can be \ +configured in your Control Panel under:\n\nSecurity -->\nAdvanced \ +Settings -->\nManage third party OAuth Client access' +% (SERVICE_ACCOUNT_CLIENT_ID, ','.join(scope))) + sys.exit(5) + else: + print('Error: %s' % e) + if soft_errors: + return False + sys.exit(4) + +def callGAPI(service, function, soft_errors=False, throw_reasons=[], **kwargs): + method = getattr(service, function) + retries = 3 + parameters = kwargs.copy() + parameters.update(extra_args) + for n in range(1, retries+1): + try: + return method(**parameters).execute() + except googleapiclient.errors.HttpError as e: + error = simplejson.loads(e.content.decode('utf-8')) + try: + reason = error['error']['errors'][0]['reason'] + http_status = error['error']['code'] + message = error['error']['errors'][0]['message'] + if reason in throw_reasons: + raise + if n != retries and reason in ['rateLimitExceeded', + 'userRateLimitExceeded', 'backendError']: + wait_on_fail = (2 ** n) if (2 ** n) < 60 else 60 + randomness = float(random.randint(1,1000)) / 1000 + wait_on_fail = wait_on_fail + randomness + if n > 3: + sys.stderr.write('\nTemp error %s. Backing off %s seconds...' 
+ % (reason, int(wait_on_fail))) + time.sleep(wait_on_fail) + if n > 3: + sys.stderr.write('attempt %s/%s\n' % (n+1, retries)) + continue + sys.stderr.write('\n%s: %s - %s\n' % (http_status, message, reason)) + if soft_errors: + sys.stderr.write(' - Giving up.\n') + return + else: + sys.exit(int(http_status)) + except KeyError: + sys.stderr.write('Unknown Error: %s' % e) + sys.exit(1) + except oauth2client.client.AccessTokenRefreshError as e: + sys.stderr.write('Error: Authentication Token Error - %s' % e) + sys.exit(403) + +def callGAPIpages(service, function, items='items', + nextPageToken='nextPageToken', page_message=None, message_attribute=None, + **kwargs): + pageToken = None + all_pages = list() + total_items = 0 + while True: + this_page = callGAPI(service=service, function=function, + pageToken=pageToken, **kwargs) + if not this_page: + this_page = {items: []} + try: + page_items = len(this_page[items]) + except KeyError: + page_items = 0 + total_items += page_items + if page_message: + show_message = page_message + try: + show_message = show_message.replace('%%num_items%%', str(page_items)) + except (IndexError, KeyError): + show_message = show_message.replace('%%num_items%%', '0') + try: + show_message = show_message.replace('%%total_items%%', + str(total_items)) + except (IndexError, KeyError): + show_message = show_message.replace('%%total_items%%', '0') + if message_attribute: + try: + show_message = show_message.replace('%%first_item%%', + str(this_page[items][0][message_attribute])) + show_message = show_message.replace('%%last_item%%', + str(this_page[items][-1][message_attribute])) + except (IndexError, KeyError): + show_message = show_message.replace('%%first_item%%', '') + show_message = show_message.replace('%%last_item%%', '') + rewrite_line(show_message) + try: + all_pages += this_page[items] + pageToken = this_page[nextPageToken] + if pageToken == '': + return all_pages + except (IndexError, KeyError): + if page_message: + 
sys.stderr.write('\n') + return all_pages + +def message_is_backed_up(message_num, sqlcur, sqlconn, backup_folder): + try: + sqlcur.execute(''' + SELECT message_filename FROM uids NATURAL JOIN messages + where uid = ?''', ((message_num),)) + except sqlite3.OperationalError as e: + if e.message == 'no such table: messages': + print("\n\nError: your backup database file appears to be corrupted.") + else: + print("SQL error:%s" % e) + sys.exit(8) + sqlresults = sqlcur.fetchall() + for x in sqlresults: + filename = x[0] + if os.path.isfile(os.path.join(backup_folder, filename)): + return True + return False + +def get_db_settings(sqlcur): + try: + sqlcur.execute('SELECT name, value FROM settings') + db_settings = dict(sqlcur) + return db_settings + except sqlite3.OperationalError as e: + if e.message == 'no such table: settings': + print("\n\nSorry, this version of GYB requires version %s of the \ +database schema. Your backup folder database does not have a version." + % (__db_schema_version__)) + sys.exit(6) + else: + print("%s" % e) + +def check_db_settings(db_settings, action, user_email_address): + if (db_settings['db_version'] < __db_schema_min_version__ or + db_settings['db_version'] > __db_schema_version__): + print("\n\nSorry, this backup folder was created with version %s of the \ +database schema while GYB %s requires version %s - %s for restores" +% (db_settings['db_version'], __version__, __db_schema_min_version__, +__db_schema_version__)) + sys.exit(4) + + # Only restores are allowed to use a backup folder started with another + # account (can't allow 2 Google Accounts to backup/estimate from same folder) + if action not in ['restore', 'restore-group', 'restore-mbox']: + if user_email_address.lower() != db_settings['email_address'].lower(): + print("\n\nSorry, this backup folder should only be used with the %s \ +account that it was created with for incremental backups. 
You specified the\ + %s account" % (db_settings['email_address'], user_email_address)) + sys.exit(5) + +def convertDB(sqlconn, uidvalidity, oldversion): + print("Converting database") + try: + with sqlconn: + if oldversion < '3': + # Convert to schema 3 + sqlconn.executescript(''' + BEGIN; + CREATE TABLE uids + (message_num INTEGER, uid INTEGER PRIMARY KEY); + INSERT INTO uids (uid, message_num) + SELECT message_num as uid, message_num FROM messages; + CREATE INDEX labelidx ON labels (message_num); + CREATE INDEX flagidx ON flags (message_num); + ''') + if oldversion < '4': + # Convert to schema 4 + sqlconn.execute(''' + ALTER TABLE messages ADD COLUMN rfc822_msgid TEXT; + ''') + if oldversion < '5': + # Convert to schema 5 + sqlconn.executescript(''' + DROP INDEX labelidx; + DROP INDEX flagidx; + CREATE UNIQUE INDEX labelidx ON labels (message_num, label); + CREATE UNIQUE INDEX flagidx ON flags (message_num, flag); + ''') + if oldversion < '6': + # Convert to schema 6 + getMessageIDs(sqlconn, options.local_folder) + rebuildUIDTable(sqlconn) + sqlconn.executemany('REPLACE INTO settings (name, value) VALUES (?,?)', + (('uidvalidity',uidvalidity), + ('db_version', __db_schema_version__)) ) + sqlconn.commit() + except sqlite3.OperationalError as e: + print("Conversion error: %s" % e.message) + + print("GYB database converted to version %s" % __db_schema_version__) + +def getMessageIDs (sqlconn, backup_folder): + sqlcur = sqlconn.cursor() + header_parser = email.parser.HeaderParser() + for message_num, filename in sqlconn.execute(''' + SELECT message_num, message_filename FROM messages + WHERE rfc822_msgid IS NULL'''): + message_full_filename = os.path.join(backup_folder, filename) + if os.path.isfile(message_full_filename): + f = open(message_full_filename, 'rb') + msgid = header_parser.parse(f, True).get('message-id') or '' + f.close() + sqlcur.execute( + 'UPDATE messages SET rfc822_msgid = ? 
WHERE message_num = ?', + (msgid, message_num)) + sqlconn.commit() + +def rebuildUIDTable(sqlconn): + pass + +def doesTokenMatchEmail(): + if options.use_admin: + auth_as = options.use_admin + else: + auth_as = options.email + oa2 = buildGAPIObject('oauth2') + user_info = callGAPI(service=oa2.userinfo(), function='get', + fields='email') + if user_info['email'].lower() == auth_as.lower(): + return True + print("Error: you did not authorize the OAuth token in the browser with the \ +%s Google Account. Please make sure you are logged in to the correct account \ +when authorizing the token in the browser." % auth_as) + cfgFile = '%s%s.cfg' % (getProgPath(), auth_as) + os.remove(cfgFile) + return False + +def restart_line(): + sys.stdout.write('\r') + sys.stdout.flush() + +def rewrite_line(mystring): + sys.stdout.write('\r') + sys.stdout.flush() + padding_length = 80 - len(mystring) + padding = ' ' * padding_length + sys.stdout.write(mystring) + sys.stdout.write(padding) + sys.stdout.flush() + +def initializeDB(sqlcur, sqlconn, email): + sqlcur.executescript(''' + CREATE TABLE messages(message_num INTEGER PRIMARY KEY, + message_filename TEXT, + message_internaldate TIMESTAMP); + CREATE TABLE labels (message_num INTEGER, label TEXT); + CREATE TABLE uids (message_num INTEGER, uid TEXT PRIMARY KEY); + CREATE TABLE settings (name TEXT PRIMARY KEY, value TEXT); + CREATE UNIQUE INDEX labelidx ON labels (message_num, label); + ''') + sqlcur.executemany('INSERT INTO settings (name, value) VALUES (?, ?)', + (('email_address', email), + ('db_version', __db_schema_version__))) + sqlconn.commit() + +def getGYBVersion(divider="\n"): + return ('Got Your Back %s~DIV~%s~DIV~%s - %s~DIV~Python %s.%s.%s %s-bit \ +%s~DIV~%s %s' % (__version__, __website__, __author__, __email__, +sys.version_info[0], sys.version_info[1], sys.version_info[2], +struct.calcsize('P')*8, sys.version_info[3], platform.platform(), +platform.machine())).replace('~DIV~', divider) + +def 
labelIdsToLabels(labelIds): + global allLabelIds, gmail + labels = list() + for labelId in labelIds: + if labelId not in allLabelIds: + # refresh allLabelIds from Google + label_results = callGAPI(service=gmail.users().labels(), function='list', + userId='me', fields='labels(name,id,type)') + allLabelIds = dict() + for a_label in label_results['labels']: + if a_label['type'] == 'system': + allLabelIds[a_label['id']] = a_label['id'] + else: + allLabelIds[a_label['id']] = a_label['name'] + try: + labels.append(allLabelIds[labelId]) + except KeyError: + pass + return labels + +def labelsToLabelIds(labels): + global allLabels + if len(allLabels) < 1: # first fetch of all labels from Google + label_results = callGAPI(service=gmail.users().labels(), function='list', + userId='me', fields='labels(name,id,type)') + allLabels = dict() + for a_label in label_results['labels']: + if a_label['type'] == 'system': + allLabels[a_label['id']] = a_label['id'] + else: + allLabels[a_label['name']] = a_label['id'] + labelIds = list() + for label in labels: + if label == 'CHAT': + label = 'Chat-Restored' + if label not in allLabels: + # create new label (or get it's id if it exists) + label_results = callGAPI(service=gmail.users().labels(), function='create', + body={'labelListVisibility': 'labelShow', + 'messageListVisibility': 'show', 'name': label}, + userId='me', fields='id') + allLabels[label] = label_results['id'] + try: + labelIds.append(allLabels[label]) + except KeyError: + pass + if label.find('/') != -1: + # make sure to create parent labels for proper nesting + parent_label = label[:label.rfind('/')] + while True: + if not parent_label in allLabels: + label_result = callGAPI(service=gmail.users().labels(), + function='create', userId='me', body={'name': parent_label}) + allLabels[parent_label] = label_result['id'] + if parent_label.find('/') == -1: + break + parent_label = parent_label[:parent_label.rfind('/')] + return labelIds + +def refresh_message(request_id, response, 
exception): + if exception is not None: + raise exception + else: + if 'labelIds' in response: + labels = labelIdsToLabels(response['labelIds']) + sqlcur.execute('DELETE FROM current_labels') + sqlcur.executemany( + 'INSERT INTO current_labels (label) VALUES (?)', + ((label,) for label in labels)) + sqlcur.execute("""DELETE FROM labels where message_num = + (SELECT message_num from uids where uid = ?) + AND label NOT IN current_labels""", ((response['id']),)) + sqlcur.execute("""INSERT INTO labels (message_num, label) + SELECT message_num, label from uids, current_labels + WHERE uid = ? AND label NOT IN + (SELECT label FROM labels + WHERE message_num = uids.message_num)""", + ((response['id']),)) + +def restored_message(request_id, response, exception): + if exception is not None: + raise exception + else: + sqlconn.execute( + '''INSERT OR IGNORE INTO restored_messages (message_num) VALUES (?)''', + (request_id,)) + +def purged_message(request_id, response, exception): + if exception is not None: + raise exception + +def estimate_message(request_id, response, exception): + global message_size_estimate + if exception is not None: + raise exception + else: + this_message_size = int(response['sizeEstimate']) + message_size_estimate += this_message_size + +def backup_message(request_id, response, exception): + if exception is not None: + print(exception) + else: + if 'labelIds' in response: + labelIds = response['labelIds'] + else: + labelIds = list() + labels = labelIdsToLabels(labelIds) + message_file_name = "%s.eml" % (response['id']) + message_time = int(response['internalDate'])/1000 + message_date = time.gmtime(message_time) + time_for_sqlite = datetime.datetime.fromtimestamp(message_time) + message_rel_path = os.path.join(str(message_date.tm_year), + str(message_date.tm_mon), + str(message_date.tm_mday)) + message_rel_filename = os.path.join(message_rel_path, + message_file_name) + message_full_path = os.path.join(options.local_folder, + message_rel_path) + 
message_full_filename = os.path.join(options.local_folder, + message_rel_filename) + if not os.path.isdir(message_full_path): + os.makedirs(message_full_path) + f = open(message_full_filename, 'wb') + raw_message = str(response['raw']) + full_message = base64.urlsafe_b64decode(raw_message) + f.write(full_message) + f.close() + sqlcur.execute(""" + INSERT INTO messages ( + message_filename, + message_internaldate) VALUES (?, ?)""", + (message_rel_filename, + time_for_sqlite)) + message_num = sqlcur.lastrowid + sqlcur.execute(""" + REPLACE INTO uids (message_num, uid) VALUES (?, ?)""", + (message_num, response['id'])) + for label in labels: + sqlcur.execute(""" + INSERT INTO labels (message_num, label) VALUES (?, ?)""", + (message_num, label)) + +def bytes_to_larger(myval): + myval = int(myval) + mysize = 'b' + if myval > 1024: + myval = myval / 1024 + mysize = 'kb' + if myval > 1024: + myval = myval / 1024 + mysize = 'mb' + if myval > 1024: + myval = myval / 1024 + mysize = 'gb' + if myval > 1024: + myval = myval / 1024 + mysize = 'tb' + return '%.2f%s' % (myval, mysize) + +def main(argv): + global options, gmail + options = SetupOptionParser(argv) + if options.version: + print(getGYBVersion()) + sys.exit(0) + if not options.email: + print('ERROR: --email is required.') + sys.exit(1) + if options.local_folder == 'XXXuse-email-addressXXX': + options.local_folder = "GYB-GMail-Backup-%s" % options.email + if not options.service_account: # 3-Legged OAuth + requestOAuthAccess() + if not doesTokenMatchEmail(): + sys.exit(9) + gmail = buildGAPIObject('gmail') + else: + gmail = buildGAPIServiceObject('gmail') + if not os.path.isdir(options.local_folder): + if options.action in ['backup',]: + os.mkdir(options.local_folder) + elif options.action in ['restore', 'restore-group']: + print('Error: Folder %s does not exist. Cannot restore.' 
+ % options.local_folder) + sys.exit(3) + + sqldbfile = os.path.join(options.local_folder, 'msg-db.sqlite') + # Do we need to initialize a new database? + newDB = (not os.path.isfile(sqldbfile)) and \ + (options.action in ['backup', 'restore-mbox']) + + # If we're not doing a estimate or if the db file actually exists we open it + # (creates db if it doesn't exist) + if options.action not in ['count', 'purge', 'purge-labels', + 'quota', 'revoke']: + if options.action not in ['estimate'] or os.path.isfile(sqldbfile): + print("\nUsing backup folder %s" % options.local_folder) + global sqlconn + global sqlcur + sqlconn = sqlite3.connect(sqldbfile, + detect_types=sqlite3.PARSE_DECLTYPES) + sqlconn.text_factory = str + sqlcur = sqlconn.cursor() + if newDB: + initializeDB(sqlcur, sqlconn, options.email) + db_settings = get_db_settings(sqlcur) + check_db_settings(db_settings, options.action, options.email) + if options.action not in ['restore', 'restore-group', 'restore-mbox']: + if db_settings['db_version'] < __db_schema_version__: + convertDB(sqlconn, db_settings['db_version']) + db_settings = get_db_settings(sqlcur) + if options.action == 'reindex': + getMessageIDs(sqlconn, options.local_folder) + rebuildUIDTable(sqlconn) + sqlconn.commit() + sys.exit(0) + + # BACKUP # + if options.action == 'backup': + if options.batch_size == 0: + options.batch_size = 100 + page_message = 'Got %%total_items%% Message IDs' + messages_to_process = callGAPIpages(service=gmail.users().messages(), + function='list', items='messages', page_message=page_message, + userId='me', includeSpamTrash=options.spamtrash, q=options.gmail_search, + fields='nextPageToken,messages/id') + backup_path = options.local_folder + if not os.path.isdir(backup_path): + os.mkdir(backup_path) + messages_to_backup = [] + messages_to_refresh = [] + #Determine which messages from the search we haven't processed before. 
+ print("GYB needs to examine %s messages" % len(messages_to_process)) + for message_num in messages_to_process: + if not newDB and message_is_backed_up(message_num['id'], sqlcur, sqlconn, + options.local_folder): + messages_to_refresh.append(message_num['id']) + else: + messages_to_backup.append(message_num['id']) + print("GYB already has a backup of %s messages" % + (len(messages_to_process) - len(messages_to_backup))) + backup_count = len(messages_to_backup) + print("GYB needs to backup %s messages" % backup_count) + backed_up_messages = 0 + gbatch = googleapiclient.http.BatchHttpRequest() + for a_message in messages_to_backup: + gbatch.add(gmail.users().messages().get(userId='me', + id=a_message, format='raw', + fields='id,labelIds,internalDate,raw'), + callback=backup_message) + backed_up_messages += 1 + if len(gbatch._order) == options.batch_size: + gbatch.execute() + gbatch = googleapiclient.http.BatchHttpRequest() + sqlconn.commit() + rewrite_line("backed up %s of %s messages" % + (backed_up_messages, backup_count)) + if len(gbatch._order) > 0: + gbatch.execute() + sqlconn.commit() + rewrite_line("backed up %s of %s messages" % + (backed_up_messages, backup_count)) + print("\n") + + if not options.refresh: + messages_to_refresh = [] + refreshed_messages = 0 + refresh_count = len(messages_to_refresh) + print("GYB needs to refresh %s messages" % refresh_count) + sqlcur.executescript(""" + CREATE TEMP TABLE current_labels (label TEXT); + """) + gbatch = googleapiclient.http.BatchHttpRequest() + for a_message in messages_to_refresh: + gbatch.add(gmail.users().messages().get(userId='me', + id=a_message, format='minimal', + fields='id,labelIds'), + callback=refresh_message) + refreshed_messages += 1 + if len(gbatch._order) == options.batch_size: + gbatch.execute() + gbatch = googleapiclient.http.BatchHttpRequest() + sqlconn.commit() + rewrite_line("refreshed %s of %s messages" % + (refreshed_messages, refresh_count)) + if len(gbatch._order) > 0: + 
gbatch.execute() + sqlconn.commit() + rewrite_line("refreshed %s of %s messages" % + (refreshed_messages, refresh_count)) + print("\n") + + # RESTORE # + elif options.action == 'restore': + if options.batch_size == 0: + options.batch_size = 10 + resumedb = os.path.join(options.local_folder, + "%s-restored.sqlite" % options.email) + if options.noresume: + try: + os.remove(resumedb) + except OSError: + pass + except IOError: + pass + sqlcur.execute('ATTACH ? as resume', (resumedb,)) + sqlcur.executescript('''CREATE TABLE IF NOT EXISTS resume.restored_messages + (message_num INTEGER PRIMARY KEY); + CREATE TEMP TABLE skip_messages (message_num INTEGER \ + PRIMARY KEY);''') + sqlcur.execute('''INSERT INTO skip_messages SELECT message_num from \ + restored_messages''') + sqlcur.execute('''SELECT message_num, message_internaldate, \ + message_filename FROM messages + WHERE message_num NOT IN skip_messages ORDER BY \ + message_internaldate DESC''') # All messages + + restore_serv = gmail.users().messages() + if options.fast_restore: + restore_func = 'insert' + restore_params = {'internalDateSource': 'dateHeader'} + else: + restore_func = 'import_' + restore_params = {'neverMarkSpam': True} + restore_method = getattr(restore_serv, restore_func) + messages_to_restore_results = sqlcur.fetchall() + restore_count = len(messages_to_restore_results) + current = 0 + gbatch = googleapiclient.http.BatchHttpRequest() + max_batch_bytes = 8 * 1024 * 1024 + current_batch_bytes = 5000 # accounts for metadata + largest_in_batch = 0 + for x in messages_to_restore_results: + current += 1 + message_filename = x[2] + message_num = x[0] + if not os.path.isfile(os.path.join(options.local_folder, + message_filename)): + print('WARNING! 
file %s does not exist for message %s' + % (os.path.join(options.local_folder, message_filename), + message_num)) + print(' this message will be skipped.') + continue + f = open(os.path.join(options.local_folder, message_filename), 'rb') + full_message = f.read() + f.close() + #full_message = full_message.replace('\x00', '') # No NULL chars + labels = [] + if not options.strip_labels: + sqlcur.execute('SELECT DISTINCT label FROM labels WHERE message_num \ + = ?', (message_num,)) + labels_results = sqlcur.fetchall() + for l in labels_results: + labels.append(l[0]) + if options.label_restored: + for restore_label in options.label_restored: + labels.append(restore_label) + labelIds = labelsToLabelIds(labels) + body = {'labelIds': labelIds} + b64_message_size = (len(full_message)/3) * 4 + if b64_message_size > 1 * 1024 * 1024: + # don't batch/raw >1mb messages, just do single + rewrite_line('restoring single large message (%s/%s)' % + (current, restore_count)) + media_body = googleapiclient.http.MediaInMemoryUpload(full_message, + mimetype='message/rfc822') + response = callGAPI(service=restore_serv, function=restore_func, + userId='me', media_body=media_body, body=body, + deleted=options.vault, **restore_params) + restored_message(request_id=str(message_num), response=response, + exception=None) + rewrite_line('restored single large message (%s/%s)' % (current, + restore_count)) + continue + if b64_message_size > largest_in_batch: + largest_in_batch = b64_message_size + raw_message = base64.urlsafe_b64encode(full_message).decode('utf-8') + body['raw'] = raw_message + current_batch_bytes += len(raw_message) + for labelId in labelIds: + current_batch_bytes += len(labelId) + if len(gbatch._order) > 0 and current_batch_bytes > max_batch_bytes: + # this message would put us over max, execute current batch first + rewrite_line("restoring %s messages (%s/%s)" % (len(gbatch._order), + current, restore_count)) + gbatch.execute() + gbatch = 
googleapiclient.http.BatchHttpRequest() + sqlconn.commit() + rewrite_line("restored %s messages (%s/%s)" % (len(gbatch._order), + current, restore_count)) + current_batch_bytes = 5000 + largest_in_batch = 0 + gbatch.add(restore_method(userId='me', + body=body, fields='id', deleted=options.vault, + **restore_params), callback=restored_message, + request_id=str(message_num)) + if len(gbatch._order) == options.batch_size: + rewrite_line("restoring %s messages (%s/%s)" % (len(gbatch._order), + current, restore_count)) + gbatch.execute() + gbatch = googleapiclient.http.BatchHttpRequest() + sqlconn.commit() + rewrite_line("restored %s messages (%s/%s)" % (len(gbatch._order), + current, restore_count)) + current_batch_bytes = 5000 + largest_in_batch = 0 + if len(gbatch._order) > 0: + rewrite_line("restoring %s messages (%s/%s)" % (len(gbatch._order), + current, restore_count)) + gbatch.execute() + sqlconn.commit() + rewrite_line("restored %s messages (%s/%s)" % (len(gbatch._order), + current, restore_count)) + print("\n") + sqlconn.execute('DETACH resume') + sqlconn.commit() + + # RESTORE-MBOX # + elif options.action == 'restore-mbox': + if options.batch_size == 0: + options.batch_size = 10 + resumedb = os.path.join(options.local_folder, + "%s-restored.sqlite" % options.email) + if options.noresume: + try: + os.remove(resumedb) + except OSError: + pass + except IOError: + pass + sqlcur.execute('ATTACH ? 
as mbox_resume', (resumedb,)) + sqlcur.executescript('''CREATE TABLE + IF NOT EXISTS mbox_resume.restored_messages + (message_num TEXT PRIMARY KEY)''') + sqlcur.execute('SELECT message_num FROM mbox_resume.restored_messages') + messages_to_skip_results = sqlcur.fetchall() + messages_to_skip = [] + for a_message in messages_to_skip_results: + messages_to_skip.append(a_message[0]) + if os.name == 'windows' or os.name == 'nt': + divider = '\\' + else: + divider = '/' + current_batch_bytes = 5000 + gbatch = googleapiclient.http.BatchHttpRequest() + restore_serv = gmail.users().messages() + if options.fast_restore: + restore_func = 'insert' + restore_params = {'internalDateSource': 'dateHeader'} + else: + restore_func = 'import_' + restore_params = {'neverMarkSpam': True} + restore_method = getattr(restore_serv, restore_func) + max_batch_bytes = 8 * 1024 * 1024 + for path, subdirs, files in os.walk(options.local_folder): + for filename in files: + if filename[-4:].lower() != '.mbx' and \ + filename[-5:].lower() != '.mbox': + continue + file_path = '%s%s%s' % (path, divider, filename) + mbox = mailbox.mbox(file_path) + restore_count = len(list(mbox.items())) + current = 0 + print("\nRestoring from %s" % file_path) + for message in mbox: + current += 1 + message_marker = '%s-%s' % (file_path, current) + if message_marker in messages_to_skip: + continue + restart_line() + labels = message['X-Gmail-Labels'] + if labels != None and labels != '' and not options.strip_labels: + mybytes, encoding = email.header.decode_header(labels)[0] + if encoding != None: + try: + labels = mybytes.decode(encoding) + except UnicodeDecodeError: + pass + else: + labels = labels.decode('string-escape') + labels = labels.split(',') + else: + labels = [] + if options.label_restored: + for restore_label in options.label_restored: + labels.append(restore_label) + labelIds = labelsToLabelIds(labels) + del message['X-Gmail-Labels'] + del message['X-GM-THRID'] + rewrite_line(" message %s of %s" % 
(current, restore_count)) + full_message = message.as_string() + body = {'labelIds': labelIds} + b64_message_size = (len(full_message)/3) * 4 + if b64_message_size > 1 * 1024 * 1024: + # don't batch/raw >1mb messages, just do single + rewrite_line(' restoring single large message (%s/%s)' % + (current, restore_count)) + media_body = googleapiclient.http.MediaInMemoryUpload(full_message, + mimetype='message/rfc822') + response = callGAPI(service=restore_serv, function=restore_func, + userId='me', media_body=media_body, body=body, + deleted=options.vault, **restore_params) + restored_message(request_id=str(message_marker), response=response, + exception=None) + rewrite_line(' restored single large message (%s/%s)' % + (current, restore_count)) + continue + raw_message = base64.urlsafe_b64encode(full_message) + body['raw'] = raw_message + current_batch_bytes += len(raw_message) + for labelId in labelIds: + current_batch_bytes += len(labelId) + if len(gbatch._order) > 0 and current_batch_bytes > max_batch_bytes: + # this message would put us over max, execute current batch first + rewrite_line("restoring %s messages (%s/%s)" % + (len(gbatch._order), current, restore_count)) + gbatch.execute() + gbatch = googleapiclient.http.BatchHttpRequest() + sqlconn.commit() + rewrite_line("restored %s messages (%s/%s)" % + (len(gbatch._order), current, restore_count)) + current_batch_bytes = 5000 + largest_in_batch = 0 + gbatch.add(restore_method(userId='me', + body=body, fields='id', + deleted=options.vault, **restore_params), + callback=restored_message, + request_id=str(message_marker)) + if len(gbatch._order) == options.batch_size: + rewrite_line("restoring %s messages (%s/%s)" % + (len(gbatch._order), current, restore_count)) + gbatch.execute() + gbatch = googleapiclient.http.BatchHttpRequest() + sqlconn.commit() + rewrite_line("restored %s messages (%s/%s)" % + (len(gbatch._order), current, restore_count)) + current_batch_bytes = 5000 + largest_in_batch = 0 + if 
len(gbatch._order) > 0: + rewrite_line("restoring %s messages (%s/%s)" % + (len(gbatch._order), current, restore_count)) + gbatch.execute() + sqlconn.commit() + rewrite_line("restoring %s messages (%s/%s)" % + (len(gbatch._order), current, restore_count)) + sqlconn.execute('DETACH mbox_resume') + sqlconn.commit() + + # RESTORE-GROUP # + elif options.action == 'restore-group': + if not options.service_account: # 3-Legged OAuth + gmig = buildGAPIObject('groupsmigration') + else: + gmig = buildGAPIServiceObject('groupsmigration') + resumedb = os.path.join(options.local_folder, + "%s-restored.sqlite" % options.email) + if options.noresume: + try: + os.remove(resumedb) + except OSError: + pass + except IOError: + pass + sqlcur.execute('ATTACH ? as resume', (resumedb,)) + sqlcur.executescript('''CREATE TABLE IF NOT EXISTS resume.restored_messages + (message_num INTEGER PRIMARY KEY); + CREATE TEMP TABLE skip_messages (message_num INTEGER PRIMARY KEY);''') + sqlcur.execute('''INSERT INTO skip_messages SELECT message_num + FROM restored_messages''') + sqlcur.execute('''SELECT message_num, message_internaldate, + message_filename FROM messages + WHERE message_num NOT IN skip_messages + ORDER BY message_internaldate DESC''') + messages_to_restore_results = sqlcur.fetchall() + restore_count = len(messages_to_restore_results) + current = 0 + for x in messages_to_restore_results: + current += 1 + rewrite_line("restoring message %s of %s from %s" % + (current, restore_count, x[1])) + message_num = x[0] + message_filename = x[2] + if not os.path.isfile(os.path.join(options.local_folder, + message_filename)): + print('WARNING! 
file %s does not exist for message %s' % + (os.path.join(options.local_folder, message_filename), message_num)) + print(' this message will be skipped.') + continue + f = open(os.path.join(options.local_folder, message_filename), 'rb') + full_message = f.read() + f.close() + media = googleapiclient.http.MediaFileUpload( + os.path.join(options.local_folder, message_filename), + mimetype='message/rfc822') + try: + callGAPI(service=gmig.archive(), function='insert', + groupId=options.email, media_body=media) + except googleapiclient.errors.MediaUploadSizeError as e: + print('\n ERROR: Message is to large for groups (16mb limit). \ + Skipping...') + continue + sqlconn.execute( + 'INSERT OR IGNORE INTO restored_messages (message_num) VALUES (?)', + (message_num,)) + sqlconn.commit() + sqlconn.execute('DETACH resume') + sqlconn.commit() + + # COUNT + elif options.action == 'count': + if options.batch_size == 0: + options.batch_size = 100 + messages_to_process = callGAPIpages(service=gmail.users().messages(), + function='list', items='messages', + userId='me', includeSpamTrash=options.spamtrash, q=options.gmail_search, + fields='nextPageToken,messages/id') + estimate_count = len(messages_to_process) + print("%s,%s" % (options.email, estimate_count)) + + # PURGE # + elif options.action == 'purge': + if options.batch_size == 0: + options.batch_size = 20 + page_message = 'Got %%total_items%% Message IDs' + messages_to_process = callGAPIpages(service=gmail.users().messages(), + function='list', items='messages', page_message=page_message, + userId='me', includeSpamTrash=True, q=options.gmail_search, + fields='nextPageToken,messages/id') + purge_count = len(messages_to_process) + purged_messages = 0 + gbatch = googleapiclient.http.BatchHttpRequest() + for a_message in messages_to_process: + gbatch.add(gmail.users().messages().delete(userId='me', + id=a_message['id']), callback=purged_message) + purged_messages += 1 + if len(gbatch._order) == options.batch_size: + 
gbatch.execute() + gbatch = googleapiclient.http.BatchHttpRequest() + rewrite_line("purged %s of %s messages" % + (purged_messages, purge_count)) + if len(gbatch._order) > 0: + gbatch.execute() + rewrite_line("purged %s of %s messages" % (purged_messages, purge_count)) + print("\n") + + # PURGE-LABELS # + elif options.action == 'purge-labels': + pattern = options.gmail_search + if pattern == None: + pattern = '.*' + pattern = re.compile(pattern) + existing_labels = callGAPI(service=gmail.users().labels(), function='list', + userId='me', fields='labels(id,name,type)') + for label_result in existing_labels['labels']: + if label_result['type'] == 'system' or not \ + pattern.search(label_result['name']): + continue + rewrite_line('Deleting label %s' % label_result['name']) + callGAPI(service=gmail.users().labels(), function='delete', + userId='me', id=label_result['id']) + print('\n') + + # QUOTA # + elif options.action == 'quota': + if not options.service_account: # 3-Legged OAuth + drive = buildGAPIObject('drive') + else: + drive = buildGAPIServiceObject('drive') + quota_results = callGAPI(service=drive.about(), function='get', + fields='quotaBytesTotal,quotaBytesUsedInTrash,quotaBytesUsedAggregate,qu\ + otaBytesByService,quotaType') + for key in quota_results: + if key == 'quotaBytesByService': + print('Service Usage:') + for service in quota_results[key]: + myval = int(service['bytesUsed']) + myval = bytes_to_larger(myval) + service_name = '%s%s' % (service['serviceName'][0], + service['serviceName'][1:].lower()) + print(' %s: %s' % (service_name, myval)) + continue + myval = quota_results[key] + mysize = '' + if myval.isdigit(): + myval = bytes_to_larger(myval) + print('%s: %s' % (key, myval)) + + # REVOKE + elif options.action == 'revoke': + if options.service_account: + print('ERROR: --action revoke does not work with --service-account') + sys.exit(5) + oauth2file = getProgPath()+'%s.cfg' % options.email + storage = oauth2client.file.Storage(oauth2file) + 
credentials = storage.get() + try: + credentials.revoke_uri = oauth2client.GOOGLE_REVOKE_URI + except AttributeError: + print('Error: Authorization doesn\'t exist') + sys.exit(1) + certFile = getProgPath()+'cacert.pem' + disable_ssl_certificate_validation = False + if os.path.isfile(getProgPath()+'noverifyssl.txt'): + disable_ssl_certificate_validation = True + http = httplib2.Http(ca_certs=certFile, + disable_ssl_certificate_validation=disable_ssl_certificate_validation) + if os.path.isfile(getProgPath()+'debug.gam'): + httplib2.debuglevel = 4 + sys.stdout.write('This authorizaton token will self-destruct in 3...') + sys.stdout.flush() + time.sleep(1) + sys.stdout.write('2...') + sys.stdout.flush() + time.sleep(1) + sys.stdout.write('1...') + sys.stdout.flush() + time.sleep(1) + sys.stdout.write('boom!\n') + sys.stdout.flush() + try: + credentials.revoke(http) + except oauth2client.client.TokenRevokeError: + print('Error') + os.remove(oauth2file) + + # ESTIMATE # + elif options.action == 'estimate': + if options.batch_size == 0: + options.batch_size = 100 + page_message = 'Got %%total_items%% Message IDs' + messages_to_process = callGAPIpages(service=gmail.users().messages(), + function='list', items='messages', page_message=page_message, + userId='me', includeSpamTrash=options.spamtrash, q=options.gmail_search, + fields='nextPageToken,messages/id') + estimate_path = options.local_folder + if not os.path.isdir(estimate_path): + os.mkdir(estimate_path) + messages_to_estimate = [] + #Determine which messages from the search we haven't processed before. 
+ print("GYB needs to examine %s messages" % len(messages_to_process)) + for message_num in messages_to_process: + if not newDB and message_is_backed_up(message_num['id'], sqlcur, + sqlconn, options.local_folder): + pass + else: + messages_to_estimate.append(message_num['id']) + print("GYB already has a backup of %s messages" % + (len(messages_to_process) - len(messages_to_estimate))) + estimate_count = len(messages_to_estimate) + print("GYB needs to estimate %s messages" % estimate_count) + estimated_messages = 0 + gbatch = googleapiclient.http.BatchHttpRequest() + global message_size_estimate + message_size_estimate = 0 + for a_message in messages_to_estimate: + gbatch.add(gmail.users().messages().get(userId='me', + id=a_message, format='minimal', + fields='sizeEstimate'), + callback=estimate_message) + estimated_messages += 1 + if len(gbatch._order) == options.batch_size: + gbatch.execute() + gbatch = googleapiclient.http.BatchHttpRequest() + sqlconn.commit() + rewrite_line("Estimated size %s %s/%s messages" % + (bytes_to_larger(message_size_estimate), estimated_messages, + estimate_count)) + if len(gbatch._order) > 0: + gbatch.execute() + sqlconn.commit() + rewrite_line("Estimated size %s %s/%s messages" % + (bytes_to_larger(message_size_estimate), estimated_messages, + estimate_count)) + print('\n') + +if __name__ == '__main__': + doGYBCheckForUpdates() + try: + main(sys.argv[1:]) + except KeyboardInterrupt: + try: + sqlconn.commit() + sqlconn.close() + print() + except NameError: + pass + sys.exit(4) \ No newline at end of file diff --git a/gyb.spec b/gyb.spec index 4866e8c..b24dc42 100644 --- a/gyb.spec +++ b/gyb.spec @@ -18,4 +18,4 @@ exe = EXE(pyz, debug=False, strip=None, upx=True, - console=True ) + console=True ) \ No newline at end of file diff --git a/httplib2/__init__.py b/httplib2/__init__.py index d1212b5..260fa6b 100644 --- a/httplib2/__init__.py +++ b/httplib2/__init__.py @@ -1,13 +1,14 @@ -from __future__ import generators + """ httplib2 A 
caching http interface that supports ETags and gzip to conserve bandwidth. -Requires Python 2.3 or later +Requires Python 3.0 or later Changelog: +2009-05-28, Pilgrim: ported to Python 3 2007-08-18, Rick: Modified so it's able to use a socks proxy if needed. """ @@ -15,27 +16,27 @@ __author__ = "Joe Gregorio (joe@bitworking.org)" __copyright__ = "Copyright 2006, Joe Gregorio" __contributors__ = ["Thomas Broyer (t.broyer@ltgt.net)", - "James Antill", - "Xavier Verges Farrero", - "Jonathan Feinberg", - "Blair Zajac", - "Sam Ruby", - "Louis Nyffenegger"] + "James Antill", + "Xavier Verges Farrero", + "Jonathan Feinberg", + "Blair Zajac", + "Sam Ruby", + "Louis Nyffenegger", + "Mark Pilgrim"] __license__ = "MIT" -__version__ = "0.9" +__version__ = "0.9.1" import re import sys import email -import email.Utils -import email.Message -import email.FeedParser -import StringIO +import email.utils +import email.message +import email.feedparser +import io import gzip import zlib -import httplib -import urlparse -import urllib +import http.client +import urllib.parse import base64 import os import copy @@ -43,71 +44,30 @@ import time import random import errno -try: - from hashlib import sha1 as _sha, md5 as _md5 -except ImportError: - # prior to Python 2.5, these were separate modules - import sha - import md5 - _sha = sha.new - _md5 = md5.new +from hashlib import sha1 as _sha, md5 as _md5 import hmac from gettext import gettext as _ import socket +import ssl +_ssl_wrap_socket = ssl.wrap_socket try: - from httplib2 import socks + import socks except ImportError: - try: - import socks - except (ImportError, AttributeError): - socks = None + socks = None -# Build the appropriate socket wrapper for ssl -try: - import ssl # python 2.6 - ssl_SSLError = ssl.SSLError - def _ssl_wrap_socket(sock, key_file, cert_file, - disable_validation, ca_certs): - if disable_validation: - cert_reqs = ssl.CERT_NONE - else: - cert_reqs = ssl.CERT_REQUIRED - # We should be specifying SSL version 3 or 
TLS v1, but the ssl module - # doesn't expose the necessary knobs. So we need to go with the default - # of SSLv23. - return ssl.wrap_socket(sock, keyfile=key_file, certfile=cert_file, - cert_reqs=cert_reqs, ca_certs=ca_certs) -except (AttributeError, ImportError): - ssl_SSLError = None - def _ssl_wrap_socket(sock, key_file, cert_file, - disable_validation, ca_certs): - if not disable_validation: - raise CertificateValidationUnsupported( - "SSL certificate validation is not supported without " - "the ssl module installed. To avoid this error, install " - "the ssl module, or explicity disable validation.") - ssl_sock = socket.ssl(sock, key_file, cert_file) - return httplib.FakeSocket(sock, ssl_sock) - - -if sys.version_info >= (2,3): - from iri2uri import iri2uri -else: - def iri2uri(uri): - return uri - -def has_timeout(timeout): # python 2.6 +from .iri2uri import iri2uri + +def has_timeout(timeout): if hasattr(socket, '_GLOBAL_DEFAULT_TIMEOUT'): return (timeout is not None and timeout is not socket._GLOBAL_DEFAULT_TIMEOUT) return (timeout is not None) -__all__ = [ - 'Http', 'Response', 'ProxyInfo', 'HttpLib2Error', 'RedirectMissingLocation', - 'RedirectLimit', 'FailedToDecompressContent', - 'UnimplementedDigestAuthOptionError', - 'UnimplementedHmacDigestAuthOptionError', - 'debuglevel', 'ProxiesUnavailableError'] +__all__ = ['Http', 'Response', 'ProxyInfo', 'HttpLib2Error', + 'RedirectMissingLocation', 'RedirectLimit', + 'FailedToDecompressContent', 'UnimplementedDigestAuthOptionError', + 'UnimplementedHmacDigestAuthOptionError', + 'debuglevel', 'RETRIES'] # The httplib debug level, set to a non-zero value to get debug output @@ -116,22 +76,6 @@ def has_timeout(timeout): # python 2.6 # A request will be tried 'RETRIES' times if it fails at the socket/connection level. 
RETRIES = 2 -# Python 2.3 support -if sys.version_info < (2,4): - def sorted(seq): - seq.sort() - return seq - -# Python 2.3 support -def HTTPResponse__getheaders(self): - """Return list of (header, value) tuples.""" - if self.msg is None: - raise httplib.ResponseNotReady() - return self.msg.items() - -if not hasattr(httplib.HTTPResponse, 'getheaders'): - httplib.HTTPResponse.getheaders = HTTPResponse__getheaders - # All exceptions raised here derive from HttpLib2Error class HttpLib2Error(Exception): pass @@ -152,15 +96,7 @@ class UnimplementedHmacDigestAuthOptionError(HttpLib2ErrorWithResponse): pass class MalformedHeader(HttpLib2Error): pass class RelativeURIError(HttpLib2Error): pass class ServerNotFoundError(HttpLib2Error): pass -class ProxiesUnavailableError(HttpLib2Error): pass -class CertificateValidationUnsupported(HttpLib2Error): pass -class SSLHandshakeError(HttpLib2Error): pass -class NotSupportedOnThisPlatform(HttpLib2Error): pass -class CertificateHostnameMismatch(SSLHandshakeError): - def __init__(self, desc, host, cert): - HttpLib2Error.__init__(self, desc) - self.host = host - self.cert = cert +class CertificateValidationUnsupportedInPython31(HttpLib2Error): pass # Open Items: # ----------- @@ -184,23 +120,17 @@ def __init__(self, desc, host, cert): # requesting that URI again. DEFAULT_MAX_REDIRECTS = 5 -try: - # Users can optionally provide a module that tells us where the CA_CERTS - # are located. - import ca_certs_locater - CA_CERTS = ca_certs_locater.get() -except ImportError: - # Default CA certificates file bundled with httplib2. - CA_CERTS = os.path.join( - os.path.dirname(os.path.abspath(__file__ )), "cacerts.txt") - # Which headers are hop-by-hop headers by default HOP_BY_HOP = ['connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade'] +# Default CA certificates file bundled with httplib2. 
+CA_CERTS = os.path.join( + os.path.dirname(os.path.abspath(__file__ )), "cacerts.txt") + def _get_end2end_headers(response): hopbyhop = list(HOP_BY_HOP) hopbyhop.extend([x.strip() for x in response.get('connection', '').split(',')]) - return [header for header in response.keys() if header not in hopbyhop] + return [header for header in list(response.keys()) if header not in hopbyhop] URI = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?") @@ -229,8 +159,9 @@ def urlnorm(uri): # Cache filename construction (original borrowed from Venus http://intertwingly.net/code/venus/) -re_url_scheme = re.compile(r'^\w+://') -re_slash = re.compile(r'[?/:|]+') +re_url_scheme = re.compile(br'^\w+://') +re_url_scheme_s = re.compile(r'^\w+://') +re_slash = re.compile(br'[?/:|]+') def safename(filename): """Return a filename suitable for the cache. @@ -240,32 +171,37 @@ def safename(filename): """ try: - if re_url_scheme.match(filename): - if isinstance(filename,str): + if re_url_scheme_s.match(filename): + if isinstance(filename,bytes): filename = filename.decode('utf-8') filename = filename.encode('idna') else: filename = filename.encode('idna') except UnicodeError: pass - if isinstance(filename,unicode): + if isinstance(filename,str): filename=filename.encode('utf-8') - filemd5 = _md5(filename).hexdigest() - filename = re_url_scheme.sub("", filename) - filename = re_slash.sub(",", filename) + filemd5 = _md5(filename).hexdigest().encode('utf-8') + filename = re_url_scheme.sub(b"", filename) + filename = re_slash.sub(b",", filename) # limit length of filename if len(filename)>200: filename=filename[:200] - return ",".join((filename, filemd5)) + return b",".join((filename, filemd5)).decode('utf-8') NORMALIZE_SPACE = re.compile(r'(?:\r\n)?[ \t]+') def _normalize_headers(headers): - return dict([ (key.lower(), NORMALIZE_SPACE.sub(value, ' ').strip()) for (key, value) in headers.iteritems()]) + return dict([ (_convert_byte_str(key).lower(), 
NORMALIZE_SPACE.sub(_convert_byte_str(value), ' ').strip()) for (key, value) in headers.items()]) +def _convert_byte_str(s): + if not isinstance(s, str): + return str(s, 'utf-8') + return s + def _parse_cache_control(headers): retval = {} - if headers.has_key('cache-control'): + if 'cache-control' in headers: parts = headers['cache-control'].split(',') parts_with_args = [tuple([x.strip().lower() for x in part.split("=", 1)]) for part in parts if -1 != part.find("=")] parts_wo_args = [(name.strip().lower(), 1) for name in parts if -1 == name.find("=")] @@ -290,9 +226,8 @@ def _parse_www_authenticate(headers, headername='www-authenticate'): """Returns a dictionary of dictionaries, one dict per auth_scheme.""" retval = {} - if headers.has_key(headername): + if headername in headers: try: - authenticate = headers[headername].strip() www_auth = USE_WWW_AUTH_STRICT_PARSING and WWW_AUTH_STRICT or WWW_AUTH_RELAXED while authenticate: @@ -312,7 +247,6 @@ def _parse_www_authenticate(headers, headername='www-authenticate'): match = www_auth.search(the_rest) retval[auth_scheme.lower()] = auth_params authenticate = the_rest.strip() - except ValueError: raise MalformedHeader("WWW-Authenticate") return retval @@ -350,39 +284,39 @@ def _entry_disposition(response_headers, request_headers): cc = _parse_cache_control(request_headers) cc_response = _parse_cache_control(response_headers) - if request_headers.has_key('pragma') and request_headers['pragma'].lower().find('no-cache') != -1: + if 'pragma' in request_headers and request_headers['pragma'].lower().find('no-cache') != -1: retval = "TRANSPARENT" if 'cache-control' not in request_headers: request_headers['cache-control'] = 'no-cache' - elif cc.has_key('no-cache'): + elif 'no-cache' in cc: retval = "TRANSPARENT" - elif cc_response.has_key('no-cache'): + elif 'no-cache' in cc_response: retval = "STALE" - elif cc.has_key('only-if-cached'): + elif 'only-if-cached' in cc: retval = "FRESH" - elif response_headers.has_key('date'): - 
date = calendar.timegm(email.Utils.parsedate_tz(response_headers['date'])) + elif 'date' in response_headers: + date = calendar.timegm(email.utils.parsedate_tz(response_headers['date'])) now = time.time() current_age = max(0, now - date) - if cc_response.has_key('max-age'): + if 'max-age' in cc_response: try: freshness_lifetime = int(cc_response['max-age']) except ValueError: freshness_lifetime = 0 - elif response_headers.has_key('expires'): - expires = email.Utils.parsedate_tz(response_headers['expires']) + elif 'expires' in response_headers: + expires = email.utils.parsedate_tz(response_headers['expires']) if None == expires: freshness_lifetime = 0 else: freshness_lifetime = max(0, calendar.timegm(expires) - date) else: freshness_lifetime = 0 - if cc.has_key('max-age'): + if 'max-age' in cc: try: freshness_lifetime = int(cc['max-age']) except ValueError: freshness_lifetime = 0 - if cc.has_key('min-fresh'): + if 'min-fresh' in cc: try: min_fresh = int(cc['min-fresh']) except ValueError: @@ -398,7 +332,7 @@ def _decompressContent(response, new_content): encoding = response.get('content-encoding', None) if encoding in ['gzip', 'deflate']: if encoding == 'gzip': - content = gzip.GzipFile(fileobj=StringIO.StringIO(new_content)).read() + content = gzip.GzipFile(fileobj=io.BytesIO(new_content)).read() if encoding == 'deflate': content = zlib.decompress(content) response['content-length'] = str(len(content)) @@ -410,15 +344,32 @@ def _decompressContent(response, new_content): raise FailedToDecompressContent(_("Content purported to be compressed with %s but failed to decompress.") % response.get('content-encoding'), response, content) return content +def _bind_write_headers(msg): + from email.header import Header + def _write_headers(self): + # Self refers to the Generator object + for h, v in msg.items(): + print('%s:' % h, end=' ', file=self._fp) + if isinstance(v, Header): + print(v.encode(maxlinelen=self._maxheaderlen), file=self._fp) + else: + # Header's got lots of 
smarts, so use it. + header = Header(v, maxlinelen=self._maxheaderlen, charset='utf-8', + header_name=h) + print(header.encode(), file=self._fp) + # A blank line always separates headers from body + print(file=self._fp) + return _write_headers + def _updateCache(request_headers, response_headers, content, cache, cachekey): if cachekey: cc = _parse_cache_control(request_headers) cc_response = _parse_cache_control(response_headers) - if cc.has_key('no-store') or cc_response.has_key('no-store'): + if 'no-store' in cc or 'no-store' in cc_response: cache.delete(cachekey) else: - info = email.Message.Message() - for key, value in response_headers.iteritems(): + info = email.message.Message() + for key, value in response_headers.items(): if key not in ['status','content-encoding','transfer-encoding']: info[key] = value @@ -440,19 +391,23 @@ def _updateCache(request_headers, response_headers, content, cache, cachekey): status_header = 'status: %d\r\n' % status - header_str = info.as_string() + try: + header_str = info.as_string() + except UnicodeEncodeError: + setattr(info, '_write_headers', _bind_write_headers(info)) + header_str = info.as_string() header_str = re.sub("\r(?!\n)|(? 0: - print "connect: (%s, %s) ************" % (self.host, self.port) - if use_proxy: - print "proxy: %s ************" % str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass)) - - self.sock.connect((self.host, self.port) + sa[2:]) - except socket.error, msg: - if self.debuglevel > 0: - print "connect fail: (%s, %s)" % (self.host, self.port) - if use_proxy: - print "proxy: %s" % str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass)) - if self.sock: - self.sock.close() - self.sock = None - continue - break - if not self.sock: - raise socket.error, msg - -class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): +class HTTPSConnectionWithTimeout(http.client.HTTPSConnection): """ This class allows communication via SSL. 
@@ -921,209 +827,34 @@ class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): the docs of socket.setdefaulttimeout(): http://docs.python.org/library/socket.html#socket.setdefaulttimeout """ + def __init__(self, host, port=None, key_file=None, cert_file=None, - strict=None, timeout=None, proxy_info=None, + timeout=None, proxy_info=None, ca_certs=None, disable_ssl_certificate_validation=False): - httplib.HTTPSConnection.__init__(self, host, port=port, - key_file=key_file, - cert_file=cert_file, strict=strict) - self.timeout = timeout self.proxy_info = proxy_info + context = None if ca_certs is None: ca_certs = CA_CERTS - self.ca_certs = ca_certs - self.disable_ssl_certificate_validation = \ - disable_ssl_certificate_validation + if (cert_file or ca_certs) and not disable_ssl_certificate_validation: + if not hasattr(ssl, 'SSLContext'): + raise CertificateValidationUnsupportedInPython31() + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + if cert_file: + context.load_cert_chain(cert_file, key_file) + if ca_certs: + context.load_verify_locations(ca_certs) + http.client.HTTPSConnection.__init__( + self, host, port=port, key_file=key_file, + cert_file=cert_file, timeout=timeout, context=context, + check_hostname=True) - # The following two methods were adapted from https_wrapper.py, released - # with the Google Appengine SDK at - # http://googleappengine.googlecode.com/svn-history/r136/trunk/python/google/appengine/tools/https_wrapper.py - # under the following license: - # - # Copyright 2007 Google Inc. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. 
- # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - - def _GetValidHostsForCert(self, cert): - """Returns a list of valid host globs for an SSL certificate. - - Args: - cert: A dictionary representing an SSL certificate. - Returns: - list: A list of valid host globs. - """ - if 'subjectAltName' in cert: - return [x[1] for x in cert['subjectAltName'] - if x[0].lower() == 'dns'] - else: - return [x[0][1] for x in cert['subject'] - if x[0][0].lower() == 'commonname'] - - def _ValidateCertificateHostname(self, cert, hostname): - """Validates that a given hostname is valid for an SSL certificate. - - Args: - cert: A dictionary representing an SSL certificate. - hostname: The hostname to test. - Returns: - bool: Whether or not the hostname is valid for this certificate. - """ - hosts = self._GetValidHostsForCert(cert) - for host in hosts: - host_re = host.replace('.', '\.').replace('*', '[^.]*') - if re.search('^%s$' % (host_re,), hostname, re.I): - return True - return False - - def connect(self): - "Connect to a host on a given (SSL) port." 
- - msg = "getaddrinfo returns an empty list" - if self.proxy_info and self.proxy_info.isgood(): - use_proxy = True - proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass = self.proxy_info.astuple() - else: - use_proxy = False - if use_proxy and proxy_rdns: - host = proxy_host - port = proxy_port - else: - host = self.host - port = self.port - - address_info = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM) - for family, socktype, proto, canonname, sockaddr in address_info: - try: - if use_proxy: - sock = socks.socksocket(family, socktype, proto) - - sock.setproxy(proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass) - else: - sock = socket.socket(family, socktype, proto) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - if has_timeout(self.timeout): - sock.settimeout(self.timeout) - sock.connect((self.host, self.port)) - self.sock =_ssl_wrap_socket( - sock, self.key_file, self.cert_file, - self.disable_ssl_certificate_validation, self.ca_certs) - if self.debuglevel > 0: - print "connect: (%s, %s)" % (self.host, self.port) - if use_proxy: - print "proxy: %s" % str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass)) - if not self.disable_ssl_certificate_validation: - cert = self.sock.getpeercert() - hostname = self.host.split(':', 0)[0] - if not self._ValidateCertificateHostname(cert, hostname): - raise CertificateHostnameMismatch( - 'Server presented certificate that does not match ' - 'host %s: %s' % (hostname, cert), hostname, cert) - except ssl_SSLError, e: - if sock: - sock.close() - if self.sock: - self.sock.close() - self.sock = None - # Unfortunately the ssl module doesn't seem to provide any way - # to get at more detailed error information, in particular - # whether the error is due to certificate validation or - # something else (such as SSL protocol mismatch). 
- if e.errno == ssl.SSL_ERROR_SSL: - raise SSLHandshakeError(e) - else: - raise - except (socket.timeout, socket.gaierror): - raise - except socket.error, msg: - if self.debuglevel > 0: - print "connect fail: (%s, %s)" % (self.host, self.port) - if use_proxy: - print "proxy: %s" % str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass)) - if self.sock: - self.sock.close() - self.sock = None - continue - break - if not self.sock: - raise socket.error, msg SCHEME_TO_CONNECTION = { 'http': HTTPConnectionWithTimeout, - 'https': HTTPSConnectionWithTimeout + 'https': HTTPSConnectionWithTimeout, } -# Use a different connection object for Google App Engine -try: - try: - from google.appengine.api import apiproxy_stub_map - if apiproxy_stub_map.apiproxy.GetStub('urlfetch') is None: - raise ImportError # Bail out; we're not actually running on App Engine. - from google.appengine.api.urlfetch import fetch - from google.appengine.api.urlfetch import InvalidURLError - except (ImportError, AttributeError): - from google3.apphosting.api import apiproxy_stub_map - if apiproxy_stub_map.apiproxy.GetStub('urlfetch') is None: - raise ImportError # Bail out; we're not actually running on App Engine. - from google3.apphosting.api.urlfetch import fetch - from google3.apphosting.api.urlfetch import InvalidURLError - - def _new_fixed_fetch(validate_certificate): - def fixed_fetch(url, payload=None, method="GET", headers={}, - allow_truncated=False, follow_redirects=True, - deadline=None): - if deadline is None: - deadline = socket.getdefaulttimeout() or 5 - return fetch(url, payload=payload, method=method, headers=headers, - allow_truncated=allow_truncated, - follow_redirects=follow_redirects, deadline=deadline, - validate_certificate=validate_certificate) - return fixed_fetch - - class AppEngineHttpConnection(httplib.HTTPConnection): - """Use httplib on App Engine, but compensate for its weirdness. 
- - The parameters key_file, cert_file, proxy_info, ca_certs, and - disable_ssl_certificate_validation are all dropped on the ground. - """ - def __init__(self, host, port=None, key_file=None, cert_file=None, - strict=None, timeout=None, proxy_info=None, ca_certs=None, - disable_ssl_certificate_validation=False): - httplib.HTTPConnection.__init__(self, host, port=port, - strict=strict, timeout=timeout) - - class AppEngineHttpsConnection(httplib.HTTPSConnection): - """Same as AppEngineHttpConnection, but for HTTPS URIs.""" - def __init__(self, host, port=None, key_file=None, cert_file=None, - strict=None, timeout=None, proxy_info=None, ca_certs=None, - disable_ssl_certificate_validation=False): - httplib.HTTPSConnection.__init__(self, host, port=port, - key_file=key_file, - cert_file=cert_file, strict=strict, - timeout=timeout) - self._fetch = _new_fixed_fetch( - not disable_ssl_certificate_validation) - - # Update the connection classes to use the Googel App Engine specific ones. - SCHEME_TO_CONNECTION = { - 'http': AppEngineHttpConnection, - 'https': AppEngineHttpsConnection - } -except (ImportError, AttributeError): - pass - - class Http(object): """An HTTP client that handles: @@ -1153,7 +884,7 @@ def __init__(self, cache=None, timeout=None, `proxy_info` may be: - a callable that takes the http scheme ('http' or 'https') and returns a ProxyInfo instance per request. By default, uses - proxy_nfo_from_environment. + proxy_info_from_environment. - a ProxyInfo instance (static proxy config). - None (proxy disabled). @@ -1163,17 +894,16 @@ def __init__(self, cache=None, timeout=None, If disable_ssl_certificate_validation is true, SSL cert validation will not be performed. - """ +""" self.proxy_info = proxy_info self.ca_certs = ca_certs self.disable_ssl_certificate_validation = \ disable_ssl_certificate_validation - # Map domain name to an httplib connection self.connections = {} # The location of the cache, for now a directory # where cached responses are held. 
- if cache and isinstance(cache, basestring): + if cache and isinstance(cache, str): self.cache = FileCache(cache) else: self.cache = cache @@ -1228,7 +958,7 @@ def _auth_from_challenge(self, host, request_uri, headers, response, content): challenges = _parse_www_authenticate(response, 'www-authenticate') for cred in self.credentials.iter(host): for scheme in AUTH_SCHEME_ORDER: - if challenges.has_key(scheme): + if scheme in challenges: yield AUTH_SCHEME_CLASSES[scheme](cred, host, request_uri, headers, response, content, self) def add_credentials(self, name, password, domain=""): @@ -1253,29 +983,21 @@ def _conn_request(self, conn, request_uri, method, body, headers): while i < RETRIES: i += 1 try: - if hasattr(conn, 'sock') and conn.sock is None: + if conn.sock is None: conn.connect() conn.request(method, request_uri, body, headers) except socket.timeout: + conn.close() raise except socket.gaierror: conn.close() raise ServerNotFoundError("Unable to find the server at %s" % conn.host) - except ssl_SSLError: - conn.close() - raise - except socket.error, e: - err = 0 - if hasattr(e, 'args'): - err = getattr(e, 'args')[0] - else: - err = e.errno - if err == errno.ECONNREFUSED: # Connection refused + except socket.error as e: + errno_ = (e.args[0].errno if isinstance(e.args[0], socket.error) else e.errno) + if errno_ == errno.ECONNREFUSED: # Connection refused raise - except httplib.HTTPException: - # Just because the server closed the connection doesn't apparently mean - # that the server didn't send a response. - if hasattr(conn, 'sock') and conn.sock is None: + except http.client.HTTPException: + if conn.sock is None: if i < RETRIES-1: conn.close() conn.connect() @@ -1287,9 +1009,12 @@ def _conn_request(self, conn, request_uri, method, body, headers): conn.close() conn.connect() continue + # Just because the server closed the connection doesn't apparently mean + # that the server didn't send a response. 
+ pass try: response = conn.getresponse() - except httplib.BadStatusLine: + except (http.client.BadStatusLine, http.client.ResponseNotReady): # If we get a BadStatusLine on the first try then that means # the connection just went stale, so retry regardless of the # number of RETRIES set. @@ -1302,16 +1027,18 @@ def _conn_request(self, conn, request_uri, method, body, headers): else: conn.close() raise - except (socket.error, httplib.HTTPException): - if i < RETRIES-1: + except socket.timeout: + raise + except (socket.error, http.client.HTTPException): + conn.close() + if i == 0: conn.close() conn.connect() continue else: - conn.close() raise else: - content = "" + content = b"" if method == "HEAD": conn.close() else: @@ -1319,6 +1046,7 @@ def _conn_request(self, conn, request_uri, method, body, headers): response = Response(response) if method != "HEAD": content = _decompressContent(response, content) + break return (response, content) @@ -1354,44 +1082,43 @@ def _request(self, conn, host, absolute_uri, request_uri, method, body, headers, # Pick out the location header and basically start from the beginning # remembering first to strip the ETag header and decrement our 'depth' if redirections: - if not response.has_key('location') and response.status != 300: + if 'location' not in response and response.status != 300: raise RedirectMissingLocation( _("Redirected but the response is missing a Location: header."), response, content) # Fix-up relative redirects (which violate an RFC 2616 MUST) - if response.has_key('location'): + if 'location' in response: location = response['location'] (scheme, authority, path, query, fragment) = parse_uri(location) if authority == None: - response['location'] = urlparse.urljoin(absolute_uri, location) + response['location'] = urllib.parse.urljoin(absolute_uri, location) if response.status == 301 and method in ["GET", "HEAD"]: response['-x-permanent-redirect-url'] = response['location'] - if not response.has_key('content-location'): 
+ if 'content-location' not in response: response['content-location'] = absolute_uri _updateCache(headers, response, content, self.cache, cachekey) - if headers.has_key('if-none-match'): + if 'if-none-match' in headers: del headers['if-none-match'] - if headers.has_key('if-modified-since'): + if 'if-modified-since' in headers: del headers['if-modified-since'] if 'authorization' in headers and not self.forward_authorization_headers: del headers['authorization'] - if response.has_key('location'): + if 'location' in response: location = response['location'] old_response = copy.deepcopy(response) - if not old_response.has_key('content-location'): + if 'content-location' not in old_response: old_response['content-location'] = absolute_uri redirect_method = method if response.status in [302, 303]: - redirect_method = "GET" - body = None + redirect_method = "GET" + body = None (response, content) = self.request( - location, method=redirect_method, - body=body, headers=headers, - redirections=redirections - 1) + location, method=redirect_method, body=body, + headers=headers, redirections=redirections - 1) response.previous = old_response else: - raise RedirectLimit("Redirected more times than rediection_limit allows.", response, content) + raise RedirectLimit("Redirected more times than redirection_limit allows.", response, content) elif response.status in [200, 203] and method in ["GET", "HEAD"]: # Don't cache 206's since we aren't going to handle byte range requests - if not response.has_key('content-location'): + if 'content-location' not in response: response['content-location'] = absolute_uri _updateCache(headers, response, content, self.cache, cachekey) @@ -1407,25 +1134,24 @@ def _normalize_headers(self, headers): def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAULT_MAX_REDIRECTS, connection_type=None): """ Performs a single HTTP request. +The 'uri' is the URI of the HTTP resource and can begin +with either 'http' or 'https'. 
The value of 'uri' must be an absolute URI. - The 'uri' is the URI of the HTTP resource and can begin with either - 'http' or 'https'. The value of 'uri' must be an absolute URI. +The 'method' is the HTTP method to perform, such as GET, POST, DELETE, etc. +There is no restriction on the methods allowed. - The 'method' is the HTTP method to perform, such as GET, POST, DELETE, - etc. There is no restriction on the methods allowed. +The 'body' is the entity body to be sent with the request. It is a string +object. - The 'body' is the entity body to be sent with the request. It is a - string object. +Any extra headers that are to be sent with the request should be provided in the +'headers' dictionary. - Any extra headers that are to be sent with the request should be - provided in the 'headers' dictionary. +The maximum number of redirect to follow before raising an +exception is 'redirections. The default is 5. - The maximum number of redirect to follow before raising an - exception is 'redirections. The default is 5. - - The return value is a tuple of (response, content), the first - being and instance of the 'Response' class, the second being - a string that contains the response entity body. +The return value is a tuple of (response, content), the first +being and instance of the 'Response' class, the second being +a string that contains the response entity body. 
""" try: if headers is None: @@ -1433,7 +1159,7 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU else: headers = self._normalize_headers(headers) - if not headers.has_key('user-agent'): + if 'user-agent' not in headers: headers['user-agent'] = "Python-httplib2/%s (gzip)" % __version__ uri = iri2uri(uri) @@ -1444,8 +1170,6 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU scheme = 'https' authority = domain_port[0] - proxy_info = self._get_proxy_info(scheme, authority) - conn_key = scheme+":"+authority if conn_key in self.connections: conn = self.connections[conn_key] @@ -1453,48 +1177,44 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU if not connection_type: connection_type = SCHEME_TO_CONNECTION[scheme] certs = list(self.certificates.iter(authority)) - if scheme == 'https': + if issubclass(connection_type, HTTPSConnectionWithTimeout): if certs: conn = self.connections[conn_key] = connection_type( authority, key_file=certs[0][0], cert_file=certs[0][1], timeout=self.timeout, - proxy_info=proxy_info, + proxy_info=self.proxy_info, ca_certs=self.ca_certs, disable_ssl_certificate_validation= self.disable_ssl_certificate_validation) else: conn = self.connections[conn_key] = connection_type( authority, timeout=self.timeout, - proxy_info=proxy_info, + proxy_info=self.proxy_info, ca_certs=self.ca_certs, disable_ssl_certificate_validation= self.disable_ssl_certificate_validation) else: conn = self.connections[conn_key] = connection_type( authority, timeout=self.timeout, - proxy_info=proxy_info) + proxy_info=self.proxy_info) conn.set_debuglevel(debuglevel) if 'range' not in headers and 'accept-encoding' not in headers: headers['accept-encoding'] = 'gzip, deflate' - info = email.Message.Message() + info = email.message.Message() cached_value = None if self.cache: cachekey = defrag_uri cached_value = self.cache.get(cachekey) if cached_value: - # info = 
email.message_from_string(cached_value) - # - # Need to replace the line above with the kludge below - # to fix the non-existent bug not fixed in this - # bug report: http://mail.python.org/pipermail/python-bugs-list/2005-September/030289.html try: - info, content = cached_value.split('\r\n\r\n', 1) - feedparser = email.FeedParser.FeedParser() - feedparser.feed(info) - info = feedparser.close() - feedparser._parse = None + info, content = cached_value.split(b'\r\n\r\n', 1) + info = email.message_from_bytes(info) + for k, v in info.items(): + if v.startswith('=?') and v.endswith('?='): + info.replace_header(k, + str(*email.header.decode_header(v)[0])) except (IndexError, ValueError): self.cache.delete(cachekey) cachekey = None @@ -1502,7 +1222,7 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU else: cachekey = None - if method in self.optimistic_concurrency_methods and self.cache and info.has_key('etag') and not self.ignore_etag and 'if-match' not in headers: + if method in self.optimistic_concurrency_methods and self.cache and 'etag' in info and not self.ignore_etag and 'if-match' not in headers: # http://www.w3.org/1999/04/Editing/ headers['if-match'] = info['etag'] @@ -1519,14 +1239,14 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU key = '-varied-%s' % header value = info[key] if headers.get(header, None) != value: - cached_value = None - break + cached_value = None + break if cached_value and method in ["GET", "HEAD"] and self.cache and 'range' not in headers: - if info.has_key('-x-permanent-redirect-url'): + if '-x-permanent-redirect-url' in info: # Should cached permanent redirects be counted in our redirection count? For now, yes. 
if redirections <= 0: - raise RedirectLimit("Redirected more times than rediection_limit allows.", {}, "") + raise RedirectLimit("Redirected more times than redirection_limit allows.", {}, "") (response, new_content) = self.request( info['-x-permanent-redirect-url'], method='GET', headers=headers, redirections=redirections - 1) @@ -1546,16 +1266,16 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU if entry_disposition == "FRESH": if not cached_value: info['status'] = '504' - content = "" + content = b"" response = Response(info) if cached_value: response.fromcache = True return (response, content) if entry_disposition == "STALE": - if info.has_key('etag') and not self.ignore_etag and not 'if-none-match' in headers: + if 'etag' in info and not self.ignore_etag and not 'if-none-match' in headers: headers['if-none-match'] = info['etag'] - if info.has_key('last-modified') and not 'last-modified' in headers: + if 'last-modified' in info and not 'last-modified' in headers: headers['if-modified-since'] = info['last-modified'] elif entry_disposition == "TRANSPARENT": pass @@ -1585,13 +1305,13 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU content = new_content else: cc = _parse_cache_control(headers) - if cc.has_key('only-if-cached'): + if 'only-if-cached'in cc: info['status'] = '504' response = Response(info) - content = "" + content = b"" else: (response, content) = self._request(conn, authority, uri, request_uri, method, body, headers, redirections, cachekey) - except Exception, e: + except Exception as e: if self.force_exception_to_status_code: if isinstance(e, HttpLib2ErrorWithResponse): response = e.response @@ -1599,7 +1319,7 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU response.status = 500 response.reason = str(e) elif isinstance(e, socket.timeout): - content = "Request Timeout" + content = b"Request Timeout" response = Response({ "content-type": 
"text/plain", "status": "408", @@ -1607,7 +1327,7 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU }) response.reason = "Request Timeout" else: - content = str(e) + content = str(e).encode('utf-8') response = Response({ "content-type": "text/plain", "status": "400", @@ -1620,23 +1340,10 @@ def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAU return (response, content) - def _get_proxy_info(self, scheme, authority): - """Return a ProxyInfo instance (or None) based on the scheme - and authority. - """ - hostname, port = urllib.splitport(authority) - proxy_info = self.proxy_info - if callable(proxy_info): - proxy_info = proxy_info(scheme) - - if (hasattr(proxy_info, 'applies_to') - and not proxy_info.applies_to(hostname)): - proxy_info = None - return proxy_info class Response(dict): - """An object more like email.Message than httplib.HTTPResponse.""" + """An object more like email.message than httplib.HTTPResponse.""" """Is this response from our local cache""" fromcache = False @@ -1653,28 +1360,31 @@ class Response(dict): previous = None def __init__(self, info): - # info is either an email.Message or + # info is either an email.message or # an httplib.HTTPResponse object. 
- if isinstance(info, httplib.HTTPResponse): + if isinstance(info, http.client.HTTPResponse): for key, value in info.getheaders(): - self[key.lower()] = value + key = key.lower() + prev = self.get(key) + if prev is not None: + value = ', '.join((prev, value)) + self[key] = value self.status = info.status self['status'] = str(self.status) self.reason = info.reason self.version = info.version - elif isinstance(info, email.Message.Message): - for key, value in info.items(): + elif isinstance(info, email.message.Message): + for key, value in list(info.items()): self[key.lower()] = value self.status = int(self['status']) else: - for key, value in info.iteritems(): + for key, value in info.items(): self[key.lower()] = value self.status = int(self.get('status', self.status)) - self.reason = self.get('reason', self.reason) def __getattr__(self, name): if name == 'dict': return self else: - raise AttributeError, name + raise AttributeError(name) diff --git a/httplib2/__init__.pyc-2.4 b/httplib2/__init__.pyc-2.4 deleted file mode 100644 index c1d1fe7..0000000 Binary files a/httplib2/__init__.pyc-2.4 and /dev/null differ diff --git a/httplib2/iri2uri.py b/httplib2/iri2uri.py index d88c91f..711377c 100644 --- a/httplib2/iri2uri.py +++ b/httplib2/iri2uri.py @@ -12,7 +12,7 @@ __history__ = """ """ -import urlparse +import urllib.parse # Convert an IRI to a URI following the rules in RFC 3987 @@ -57,7 +57,7 @@ def encode(c): if i < low: break if i >= low and i <= high: - retval = "".join(["%%%2X" % ord(o) for o in c.encode('utf-8')]) + retval = "".join(["%%%2X" % o for o in c.encode('utf-8')]) break return retval @@ -66,13 +66,13 @@ def iri2uri(uri): """Convert an IRI to a URI. Note that IRIs must be passed in a unicode strings. 
That is, do not utf-8 encode the IRI before passing it into the function.""" - if isinstance(uri ,unicode): - (scheme, authority, path, query, fragment) = urlparse.urlsplit(uri) - authority = authority.encode('idna') + if isinstance(uri ,str): + (scheme, authority, path, query, fragment) = urllib.parse.urlsplit(uri) + authority = authority.encode('idna').decode('utf-8') # For each character in 'ucschar' or 'iprivate' # 1. encode as utf-8 # 2. then %-encode each octet of that utf-8 - uri = urlparse.urlunsplit((scheme, authority, path, query, fragment)) + uri = urllib.parse.urlunsplit((scheme, authority, path, query, fragment)) uri = "".join([encode(c) for c in uri]) return uri @@ -84,26 +84,26 @@ class Test(unittest.TestCase): def test_uris(self): """Test that URIs are invariant under the transformation.""" invariant = [ - u"ftp://ftp.is.co.za/rfc/rfc1808.txt", - u"http://www.ietf.org/rfc/rfc2396.txt", - u"ldap://[2001:db8::7]/c=GB?objectClass?one", - u"mailto:John.Doe@example.com", - u"news:comp.infosystems.www.servers.unix", - u"tel:+1-816-555-1212", - u"telnet://192.0.2.16:80/", - u"urn:oasis:names:specification:docbook:dtd:xml:4.1.2" ] + "ftp://ftp.is.co.za/rfc/rfc1808.txt", + "http://www.ietf.org/rfc/rfc2396.txt", + "ldap://[2001:db8::7]/c=GB?objectClass?one", + "mailto:John.Doe@example.com", + "news:comp.infosystems.www.servers.unix", + "tel:+1-816-555-1212", + "telnet://192.0.2.16:80/", + "urn:oasis:names:specification:docbook:dtd:xml:4.1.2" ] for uri in invariant: self.assertEqual(uri, iri2uri(uri)) def test_iri(self): """ Test that the right type of escaping is done for each part of the URI.""" - self.assertEqual("http://xn--o3h.com/%E2%98%84", iri2uri(u"http://\N{COMET}.com/\N{COMET}")) - self.assertEqual("http://bitworking.org/?fred=%E2%98%84", iri2uri(u"http://bitworking.org/?fred=\N{COMET}")) - self.assertEqual("http://bitworking.org/#%E2%98%84", iri2uri(u"http://bitworking.org/#\N{COMET}")) - self.assertEqual("#%E2%98%84", iri2uri(u"#\N{COMET}")) - 
self.assertEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri(u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}")) - self.assertEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri(iri2uri(u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}"))) - self.assertNotEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri(u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}".encode('utf-8'))) + self.assertEqual("http://xn--o3h.com/%E2%98%84", iri2uri("http://\N{COMET}.com/\N{COMET}")) + self.assertEqual("http://bitworking.org/?fred=%E2%98%84", iri2uri("http://bitworking.org/?fred=\N{COMET}")) + self.assertEqual("http://bitworking.org/#%E2%98%84", iri2uri("http://bitworking.org/#\N{COMET}")) + self.assertEqual("#%E2%98%84", iri2uri("#\N{COMET}")) + self.assertEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri("/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}")) + self.assertEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri(iri2uri("/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}"))) + self.assertNotEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri("/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}".encode('utf-8'))) unittest.main() diff --git a/httplib2/iri2uri.pyc-2.4 b/httplib2/iri2uri.pyc-2.4 deleted file mode 100644 index 68d52aa..0000000 Binary files a/httplib2/iri2uri.pyc-2.4 and /dev/null differ diff --git a/httplib2/socks.py b/httplib2/socks.py deleted file mode 100644 index 0991f4c..0000000 --- a/httplib2/socks.py +++ /dev/null @@ -1,438 +0,0 @@ -"""SocksiPy - Python SOCKS module. -Version 1.00 - -Copyright 2006 Dan-Haim. All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. -2. 
Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. -3. Neither the name of Dan Haim nor the names of his contributors may be used - to endorse or promote products derived from this software without specific - prior written permission. - -THIS SOFTWARE IS PROVIDED BY DAN HAIM "AS IS" AND ANY EXPRESS OR IMPLIED -WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO -EVENT SHALL DAN HAIM OR HIS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA -OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMANGE. - - -This module provides a standard socket-like interface for Python -for tunneling connections through SOCKS proxies. 
- -""" - -""" - -Minor modifications made by Christopher Gilbert (http://motomastyle.com/) -for use in PyLoris (http://pyloris.sourceforge.net/) - -Minor modifications made by Mario Vilas (http://breakingcode.wordpress.com/) -mainly to merge bug fixes found in Sourceforge - -""" - -import base64 -import socket -import struct -import sys - -if getattr(socket, 'socket', None) is None: - raise ImportError('socket.socket missing, proxy support unusable') - -PROXY_TYPE_SOCKS4 = 1 -PROXY_TYPE_SOCKS5 = 2 -PROXY_TYPE_HTTP = 3 -PROXY_TYPE_HTTP_NO_TUNNEL = 4 - -_defaultproxy = None -_orgsocket = socket.socket - -class ProxyError(Exception): pass -class GeneralProxyError(ProxyError): pass -class Socks5AuthError(ProxyError): pass -class Socks5Error(ProxyError): pass -class Socks4Error(ProxyError): pass -class HTTPError(ProxyError): pass - -_generalerrors = ("success", - "invalid data", - "not connected", - "not available", - "bad proxy type", - "bad input") - -_socks5errors = ("succeeded", - "general SOCKS server failure", - "connection not allowed by ruleset", - "Network unreachable", - "Host unreachable", - "Connection refused", - "TTL expired", - "Command not supported", - "Address type not supported", - "Unknown error") - -_socks5autherrors = ("succeeded", - "authentication is required", - "all offered authentication methods were rejected", - "unknown username or invalid password", - "unknown error") - -_socks4errors = ("request granted", - "request rejected or failed", - "request rejected because SOCKS server cannot connect to identd on the client", - "request rejected because the client program and identd report different user-ids", - "unknown error") - -def setdefaultproxy(proxytype=None, addr=None, port=None, rdns=True, username=None, password=None): - """setdefaultproxy(proxytype, addr[, port[, rdns[, username[, password]]]]) - Sets a default proxy which all further socksocket objects will use, - unless explicitly changed. 
- """ - global _defaultproxy - _defaultproxy = (proxytype, addr, port, rdns, username, password) - -def wrapmodule(module): - """wrapmodule(module) - Attempts to replace a module's socket library with a SOCKS socket. Must set - a default proxy using setdefaultproxy(...) first. - This will only work on modules that import socket directly into the namespace; - most of the Python Standard Library falls into this category. - """ - if _defaultproxy != None: - module.socket.socket = socksocket - else: - raise GeneralProxyError((4, "no proxy specified")) - -class socksocket(socket.socket): - """socksocket([family[, type[, proto]]]) -> socket object - Open a SOCKS enabled socket. The parameters are the same as - those of the standard socket init. In order for SOCKS to work, - you must specify family=AF_INET, type=SOCK_STREAM and proto=0. - """ - - def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, _sock=None): - _orgsocket.__init__(self, family, type, proto, _sock) - if _defaultproxy != None: - self.__proxy = _defaultproxy - else: - self.__proxy = (None, None, None, None, None, None) - self.__proxysockname = None - self.__proxypeername = None - self.__httptunnel = True - - def __recvall(self, count): - """__recvall(count) -> data - Receive EXACTLY the number of bytes requested from the socket. - Blocks until the required number of bytes have been received. - """ - data = self.recv(count) - while len(data) < count: - d = self.recv(count-len(data)) - if not d: raise GeneralProxyError((0, "connection closed unexpectedly")) - data = data + d - return data - - def sendall(self, content, *args): - """ override socket.socket.sendall method to rewrite the header - for non-tunneling proxies if needed - """ - if not self.__httptunnel: - content = self.__rewriteproxy(content) - return super(socksocket, self).sendall(content, *args) - - def __rewriteproxy(self, header): - """ rewrite HTTP request headers to support non-tunneling proxies - (i.e. 
those which do not support the CONNECT method). - This only works for HTTP (not HTTPS) since HTTPS requires tunneling. - """ - host, endpt = None, None - hdrs = header.split("\r\n") - for hdr in hdrs: - if hdr.lower().startswith("host:"): - host = hdr - elif hdr.lower().startswith("get") or hdr.lower().startswith("post"): - endpt = hdr - if host and endpt: - hdrs.remove(host) - hdrs.remove(endpt) - host = host.split(" ")[1] - endpt = endpt.split(" ") - if (self.__proxy[4] != None and self.__proxy[5] != None): - hdrs.insert(0, self.__getauthheader()) - hdrs.insert(0, "Host: %s" % host) - hdrs.insert(0, "%s http://%s%s %s" % (endpt[0], host, endpt[1], endpt[2])) - return "\r\n".join(hdrs) - - def __getauthheader(self): - auth = self.__proxy[4] + ":" + self.__proxy[5] - return "Proxy-Authorization: Basic " + base64.b64encode(auth) - - def setproxy(self, proxytype=None, addr=None, port=None, rdns=True, username=None, password=None): - """setproxy(proxytype, addr[, port[, rdns[, username[, password]]]]) - Sets the proxy to be used. - proxytype - The type of the proxy to be used. Three types - are supported: PROXY_TYPE_SOCKS4 (including socks4a), - PROXY_TYPE_SOCKS5 and PROXY_TYPE_HTTP - addr - The address of the server (IP or DNS). - port - The port of the server. Defaults to 1080 for SOCKS - servers and 8080 for HTTP proxy servers. - rdns - Should DNS queries be preformed on the remote side - (rather than the local side). The default is True. - Note: This has no effect with SOCKS4 servers. - username - Username to authenticate with to the server. - The default is no authentication. - password - Password to authenticate with to the server. - Only relevant when username is also provided. - """ - self.__proxy = (proxytype, addr, port, rdns, username, password) - - def __negotiatesocks5(self, destaddr, destport): - """__negotiatesocks5(self,destaddr,destport) - Negotiates a connection through a SOCKS5 server. 
- """ - # First we'll send the authentication packages we support. - if (self.__proxy[4]!=None) and (self.__proxy[5]!=None): - # The username/password details were supplied to the - # setproxy method so we support the USERNAME/PASSWORD - # authentication (in addition to the standard none). - self.sendall(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02)) - else: - # No username/password were entered, therefore we - # only support connections with no authentication. - self.sendall(struct.pack('BBB', 0x05, 0x01, 0x00)) - # We'll receive the server's response to determine which - # method was selected - chosenauth = self.__recvall(2) - if chosenauth[0:1] != chr(0x05).encode(): - self.close() - raise GeneralProxyError((1, _generalerrors[1])) - # Check the chosen authentication method - if chosenauth[1:2] == chr(0x00).encode(): - # No authentication is required - pass - elif chosenauth[1:2] == chr(0x02).encode(): - # Okay, we need to perform a basic username/password - # authentication. - self.sendall(chr(0x01).encode() + chr(len(self.__proxy[4])) + self.__proxy[4] + chr(len(self.__proxy[5])) + self.__proxy[5]) - authstat = self.__recvall(2) - if authstat[0:1] != chr(0x01).encode(): - # Bad response - self.close() - raise GeneralProxyError((1, _generalerrors[1])) - if authstat[1:2] != chr(0x00).encode(): - # Authentication failed - self.close() - raise Socks5AuthError((3, _socks5autherrors[3])) - # Authentication succeeded - else: - # Reaching here is always bad - self.close() - if chosenauth[1] == chr(0xFF).encode(): - raise Socks5AuthError((2, _socks5autherrors[2])) - else: - raise GeneralProxyError((1, _generalerrors[1])) - # Now we can request the actual connection - req = struct.pack('BBB', 0x05, 0x01, 0x00) - # If the given destination address is an IP address, we'll - # use the IPv4 address request even if remote resolving was specified. 
- try: - ipaddr = socket.inet_aton(destaddr) - req = req + chr(0x01).encode() + ipaddr - except socket.error: - # Well it's not an IP number, so it's probably a DNS name. - if self.__proxy[3]: - # Resolve remotely - ipaddr = None - req = req + chr(0x03).encode() + chr(len(destaddr)).encode() + destaddr - else: - # Resolve locally - ipaddr = socket.inet_aton(socket.gethostbyname(destaddr)) - req = req + chr(0x01).encode() + ipaddr - req = req + struct.pack(">H", destport) - self.sendall(req) - # Get the response - resp = self.__recvall(4) - if resp[0:1] != chr(0x05).encode(): - self.close() - raise GeneralProxyError((1, _generalerrors[1])) - elif resp[1:2] != chr(0x00).encode(): - # Connection failed - self.close() - if ord(resp[1:2])<=8: - raise Socks5Error((ord(resp[1:2]), _socks5errors[ord(resp[1:2])])) - else: - raise Socks5Error((9, _socks5errors[9])) - # Get the bound address/port - elif resp[3:4] == chr(0x01).encode(): - boundaddr = self.__recvall(4) - elif resp[3:4] == chr(0x03).encode(): - resp = resp + self.recv(1) - boundaddr = self.__recvall(ord(resp[4:5])) - else: - self.close() - raise GeneralProxyError((1,_generalerrors[1])) - boundport = struct.unpack(">H", self.__recvall(2))[0] - self.__proxysockname = (boundaddr, boundport) - if ipaddr != None: - self.__proxypeername = (socket.inet_ntoa(ipaddr), destport) - else: - self.__proxypeername = (destaddr, destport) - - def getproxysockname(self): - """getsockname() -> address info - Returns the bound IP address and port number at the proxy. - """ - return self.__proxysockname - - def getproxypeername(self): - """getproxypeername() -> address info - Returns the IP and port number of the proxy. 
- """ - return _orgsocket.getpeername(self) - - def getpeername(self): - """getpeername() -> address info - Returns the IP address and port number of the destination - machine (note: getproxypeername returns the proxy) - """ - return self.__proxypeername - - def __negotiatesocks4(self,destaddr,destport): - """__negotiatesocks4(self,destaddr,destport) - Negotiates a connection through a SOCKS4 server. - """ - # Check if the destination address provided is an IP address - rmtrslv = False - try: - ipaddr = socket.inet_aton(destaddr) - except socket.error: - # It's a DNS name. Check where it should be resolved. - if self.__proxy[3]: - ipaddr = struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01) - rmtrslv = True - else: - ipaddr = socket.inet_aton(socket.gethostbyname(destaddr)) - # Construct the request packet - req = struct.pack(">BBH", 0x04, 0x01, destport) + ipaddr - # The username parameter is considered userid for SOCKS4 - if self.__proxy[4] != None: - req = req + self.__proxy[4] - req = req + chr(0x00).encode() - # DNS name if remote resolving is required - # NOTE: This is actually an extension to the SOCKS4 protocol - # called SOCKS4A and may not be supported in all cases. 
- if rmtrslv: - req = req + destaddr + chr(0x00).encode() - self.sendall(req) - # Get the response from the server - resp = self.__recvall(8) - if resp[0:1] != chr(0x00).encode(): - # Bad data - self.close() - raise GeneralProxyError((1,_generalerrors[1])) - if resp[1:2] != chr(0x5A).encode(): - # Server returned an error - self.close() - if ord(resp[1:2]) in (91, 92, 93): - self.close() - raise Socks4Error((ord(resp[1:2]), _socks4errors[ord(resp[1:2]) - 90])) - else: - raise Socks4Error((94, _socks4errors[4])) - # Get the bound address/port - self.__proxysockname = (socket.inet_ntoa(resp[4:]), struct.unpack(">H", resp[2:4])[0]) - if rmtrslv != None: - self.__proxypeername = (socket.inet_ntoa(ipaddr), destport) - else: - self.__proxypeername = (destaddr, destport) - - def __negotiatehttp(self, destaddr, destport): - """__negotiatehttp(self,destaddr,destport) - Negotiates a connection through an HTTP server. - """ - # If we need to resolve locally, we do this now - if not self.__proxy[3]: - addr = socket.gethostbyname(destaddr) - else: - addr = destaddr - headers = ["CONNECT ", addr, ":", str(destport), " HTTP/1.1\r\n"] - headers += ["Host: ", destaddr, "\r\n"] - if (self.__proxy[4] != None and self.__proxy[5] != None): - headers += [self.__getauthheader(), "\r\n"] - headers.append("\r\n") - self.sendall("".join(headers).encode()) - # We read the response until we get the string "\r\n\r\n" - resp = self.recv(1) - while resp.find("\r\n\r\n".encode()) == -1: - resp = resp + self.recv(1) - # We just need the first line to check if the connection - # was successful - statusline = resp.splitlines()[0].split(" ".encode(), 2) - if statusline[0] not in ("HTTP/1.0".encode(), "HTTP/1.1".encode()): - self.close() - raise GeneralProxyError((1, _generalerrors[1])) - try: - statuscode = int(statusline[1]) - except ValueError: - self.close() - raise GeneralProxyError((1, _generalerrors[1])) - if statuscode != 200: - self.close() - raise HTTPError((statuscode, statusline[2])) - 
self.__proxysockname = ("0.0.0.0", 0) - self.__proxypeername = (addr, destport) - - def connect(self, destpair): - """connect(self, despair) - Connects to the specified destination through a proxy. - destpar - A tuple of the IP/DNS address and the port number. - (identical to socket's connect). - To select the proxy server use setproxy(). - """ - # Do a minimal input check first - if (not type(destpair) in (list,tuple)) or (len(destpair) < 2) or (not isinstance(destpair[0], basestring)) or (type(destpair[1]) != int): - raise GeneralProxyError((5, _generalerrors[5])) - if self.__proxy[0] == PROXY_TYPE_SOCKS5: - if self.__proxy[2] != None: - portnum = self.__proxy[2] - else: - portnum = 1080 - _orgsocket.connect(self, (self.__proxy[1], portnum)) - self.__negotiatesocks5(destpair[0], destpair[1]) - elif self.__proxy[0] == PROXY_TYPE_SOCKS4: - if self.__proxy[2] != None: - portnum = self.__proxy[2] - else: - portnum = 1080 - _orgsocket.connect(self,(self.__proxy[1], portnum)) - self.__negotiatesocks4(destpair[0], destpair[1]) - elif self.__proxy[0] == PROXY_TYPE_HTTP: - if self.__proxy[2] != None: - portnum = self.__proxy[2] - else: - portnum = 8080 - _orgsocket.connect(self,(self.__proxy[1], portnum)) - self.__negotiatehttp(destpair[0], destpair[1]) - elif self.__proxy[0] == PROXY_TYPE_HTTP_NO_TUNNEL: - if self.__proxy[2] != None: - portnum = self.__proxy[2] - else: - portnum = 8080 - _orgsocket.connect(self,(self.__proxy[1],portnum)) - if destpair[1] == 443: - self.__negotiatehttp(destpair[0],destpair[1]) - else: - self.__httptunnel = False - elif self.__proxy[0] == None: - _orgsocket.connect(self, (destpair[0], destpair[1])) - else: - raise GeneralProxyError((4, _generalerrors[4])) diff --git a/httplib2/socks.pyc-2.4 b/httplib2/socks.pyc-2.4 deleted file mode 100644 index f0cc7c3..0000000 Binary files a/httplib2/socks.pyc-2.4 and /dev/null differ diff --git a/httplib2/test/__init__.py b/httplib2/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff 
--git a/httplib2/test/brokensocket/socket.py b/httplib2/test/brokensocket/socket.py deleted file mode 100644 index ff7c0b7..0000000 --- a/httplib2/test/brokensocket/socket.py +++ /dev/null @@ -1 +0,0 @@ -from realsocket import gaierror, error, getaddrinfo, SOCK_STREAM diff --git a/httplib2/test/functional/test_proxies.py b/httplib2/test/functional/test_proxies.py deleted file mode 100644 index 0b7880f..0000000 --- a/httplib2/test/functional/test_proxies.py +++ /dev/null @@ -1,88 +0,0 @@ -import unittest -import errno -import os -import signal -import subprocess -import tempfile - -import nose - -import httplib2 -from httplib2 import socks -from httplib2.test import miniserver - -tinyproxy_cfg = """ -User "%(user)s" -Port %(port)s -Listen 127.0.0.1 -PidFile "%(pidfile)s" -LogFile "%(logfile)s" -MaxClients 2 -StartServers 1 -LogLevel Info -""" - - -class FunctionalProxyHttpTest(unittest.TestCase): - def setUp(self): - if not socks: - raise nose.SkipTest('socks module unavailable') - if not subprocess: - raise nose.SkipTest('subprocess module unavailable') - - # start a short-lived miniserver so we can get a likely port - # for the proxy - self.httpd, self.proxyport = miniserver.start_server( - miniserver.ThisDirHandler) - self.httpd.shutdown() - self.httpd, self.port = miniserver.start_server( - miniserver.ThisDirHandler) - - self.pidfile = tempfile.mktemp() - self.logfile = tempfile.mktemp() - fd, self.conffile = tempfile.mkstemp() - f = os.fdopen(fd, 'w') - our_cfg = tinyproxy_cfg % {'user': os.getlogin(), - 'pidfile': self.pidfile, - 'port': self.proxyport, - 'logfile': self.logfile} - f.write(our_cfg) - f.close() - try: - # TODO use subprocess.check_call when 2.4 is dropped - ret = subprocess.call(['tinyproxy', '-c', self.conffile]) - self.assertEqual(0, ret) - except OSError, e: - if e.errno == errno.ENOENT: - raise nose.SkipTest('tinyproxy not available') - raise - - def tearDown(self): - self.httpd.shutdown() - try: - pid = int(open(self.pidfile).read()) - 
os.kill(pid, signal.SIGTERM) - except OSError, e: - if e.errno == errno.ESRCH: - print '\n\n\nTinyProxy Failed to start, log follows:' - print open(self.logfile).read() - print 'end tinyproxy log\n\n\n' - raise - map(os.unlink, (self.pidfile, - self.logfile, - self.conffile)) - - def testSimpleProxy(self): - proxy_info = httplib2.ProxyInfo(socks.PROXY_TYPE_HTTP, - 'localhost', self.proxyport) - client = httplib2.Http(proxy_info=proxy_info) - src = 'miniserver.py' - response, body = client.request('http://localhost:%d/%s' % - (self.port, src)) - self.assertEqual(response.status, 200) - self.assertEqual(body, open(os.path.join(miniserver.HERE, src)).read()) - lf = open(self.logfile).read() - expect = ('Established connection to host "127.0.0.1" ' - 'using file descriptor') - self.assertTrue(expect in lf, - 'tinyproxy did not proxy a request for miniserver') diff --git a/httplib2/test/miniserver.py b/httplib2/test/miniserver.py deleted file mode 100644 index e32bf5e..0000000 --- a/httplib2/test/miniserver.py +++ /dev/null @@ -1,100 +0,0 @@ -import logging -import os -import select -import SimpleHTTPServer -import SocketServer -import threading - -HERE = os.path.dirname(__file__) -logger = logging.getLogger(__name__) - - -class ThisDirHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): - def translate_path(self, path): - path = path.split('?', 1)[0].split('#', 1)[0] - return os.path.join(HERE, *filter(None, path.split('/'))) - - def log_message(self, s, *args): - # output via logging so nose can catch it - logger.info(s, *args) - - -class ShutdownServer(SocketServer.TCPServer): - """Mixin that allows serve_forever to be shut down. - - The methods in this mixin are backported from SocketServer.py in the Python - 2.6.4 standard library. The mixin is unnecessary in 2.6 and later, when - BaseServer supports the shutdown method directly. 
- """ - - def __init__(self, *args, **kwargs): - SocketServer.TCPServer.__init__(self, *args, **kwargs) - self.__is_shut_down = threading.Event() - self.__serving = False - - def serve_forever(self, poll_interval=0.1): - """Handle one request at a time until shutdown. - - Polls for shutdown every poll_interval seconds. Ignores - self.timeout. If you need to do periodic tasks, do them in - another thread. - """ - self.__serving = True - self.__is_shut_down.clear() - while self.__serving: - r, w, e = select.select([self.socket], [], [], poll_interval) - if r: - self._handle_request_noblock() - self.__is_shut_down.set() - - def shutdown(self): - """Stops the serve_forever loop. - - Blocks until the loop has finished. This must be called while - serve_forever() is running in another thread, or it will deadlock. - """ - self.__serving = False - self.__is_shut_down.wait() - - def handle_request(self): - """Handle one request, possibly blocking. - - Respects self.timeout. - """ - # Support people who used socket.settimeout() to escape - # handle_request before self.timeout was available. - timeout = self.socket.gettimeout() - if timeout is None: - timeout = self.timeout - elif self.timeout is not None: - timeout = min(timeout, self.timeout) - fd_sets = select.select([self], [], [], timeout) - if not fd_sets[0]: - self.handle_timeout() - return - self._handle_request_noblock() - - def _handle_request_noblock(self): - """Handle one request, without blocking. - - I assume that select.select has returned that the socket is - readable before this function was called, so there should be - no risk of blocking in get_request(). 
- """ - try: - request, client_address = self.get_request() - except socket.error: - return - if self.verify_request(request, client_address): - try: - self.process_request(request, client_address) - except: - self.handle_error(request, client_address) - self.close_request(request) - - -def start_server(handler): - httpd = ShutdownServer(("", 0), handler) - threading.Thread(target=httpd.serve_forever).start() - _, port = httpd.socket.getsockname() - return httpd, port diff --git a/httplib2/test/smoke_test.py b/httplib2/test/smoke_test.py deleted file mode 100644 index 9f1e6f0..0000000 --- a/httplib2/test/smoke_test.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -import unittest - -import httplib2 - -from httplib2.test import miniserver - - -class HttpSmokeTest(unittest.TestCase): - def setUp(self): - self.httpd, self.port = miniserver.start_server( - miniserver.ThisDirHandler) - - def tearDown(self): - self.httpd.shutdown() - - def testGetFile(self): - client = httplib2.Http() - src = 'miniserver.py' - response, body = client.request('http://localhost:%d/%s' % - (self.port, src)) - self.assertEqual(response.status, 200) - self.assertEqual(body, open(os.path.join(miniserver.HERE, src)).read()) diff --git a/httplib2/test/test_no_socket.py b/httplib2/test/test_no_socket.py deleted file mode 100644 index 66ba056..0000000 --- a/httplib2/test/test_no_socket.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Tests for httplib2 when the socket module is missing. - -This helps ensure compatibility with environments such as AppEngine. 
-""" -import os -import sys -import unittest - -import httplib2 - -class MissingSocketTest(unittest.TestCase): - def setUp(self): - self._oldsocks = httplib2.socks - httplib2.socks = None - - def tearDown(self): - httplib2.socks = self._oldsocks - - def testProxyDisabled(self): - proxy_info = httplib2.ProxyInfo('blah', - 'localhost', 0) - client = httplib2.Http(proxy_info=proxy_info) - self.assertRaises(httplib2.ProxiesUnavailableError, - client.request, 'http://localhost:-1/') diff --git a/oauth2client/__init__.py b/oauth2client/__init__.py index 7e4e122..f992cff 100644 --- a/oauth2client/__init__.py +++ b/oauth2client/__init__.py @@ -1,5 +1,8 @@ -__version__ = "1.1" +"""Client library for using OAuth2, especially with Google APIs.""" + +__version__ = '1.4.7' GOOGLE_AUTH_URI = 'https://accounts.google.com/o/oauth2/auth' +GOOGLE_DEVICE_URI = 'https://accounts.google.com/o/oauth2/device/code' GOOGLE_REVOKE_URI = 'https://accounts.google.com/o/oauth2/revoke' GOOGLE_TOKEN_URI = 'https://accounts.google.com/o/oauth2/token' diff --git a/oauth2client/appengine.py b/oauth2client/appengine.py index a6d88df..00fe985 100644 --- a/oauth2client/appengine.py +++ b/oauth2client/appengine.py @@ -1,4 +1,4 @@ -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,13 +19,14 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' -import base64 import cgi -import httplib2 +import json import logging import os import pickle -import time +import threading + +import httplib2 from google.appengine.api import app_identity from google.appengine.api import memcache @@ -40,7 +41,6 @@ from oauth2client import clientsecrets from oauth2client import util from oauth2client import xsrfutil -from oauth2client.anyjson import simplejson from oauth2client.client import AccessTokenRefreshError from oauth2client.client import AssertionCredentials from oauth2client.client import Credentials @@ -158,15 +158,20 @@ def __init__(self, scope, **kwargs): Args: scope: string or iterable of strings, scope(s) of the credentials being requested. + **kwargs: optional keyword args, including: + service_account_id: service account id of the application. If None or + unspecified, the default service account for the app is used. """ self.scope = util.scopes_to_string(scope) + self._kwargs = kwargs + self.service_account_id = kwargs.get('service_account_id', None) # Assertion type is no longer used, but still in the parent class signature. 
super(AppAssertionCredentials, self).__init__(None) @classmethod - def from_json(cls, json): - data = simplejson.loads(json) + def from_json(cls, json_data): + data = json.loads(json_data) return AppAssertionCredentials(data['scope']) def _refresh(self, http_request): @@ -185,11 +190,22 @@ def _refresh(self, http_request): """ try: scopes = self.scope.split() - (token, _) = app_identity.get_access_token(scopes) - except app_identity.Error, e: + (token, _) = app_identity.get_access_token( + scopes, service_account_id=self.service_account_id) + except app_identity.Error as e: raise AccessTokenRefreshError(str(e)) self.access_token = token + @property + def serialization_data(self): + raise NotImplementedError('Cannot serialize credentials for AppEngine.') + + def create_scoped_required(self): + return not self.scope + + def create_scoped(self, scopes): + return AppAssertionCredentials(scopes, **self._kwargs) + class FlowProperty(db.Property): """App Engine datastore Property for Flow. @@ -367,7 +383,7 @@ class StorageByKeyName(Storage): """ @util.positional(4) - def __init__(self, model, key_name, property_name, cache=None): + def __init__(self, model, key_name, property_name, cache=None, user=None): """Constructor for Storage. Args: @@ -378,7 +394,14 @@ def __init__(self, model, key_name, property_name, cache=None): cache: memcache, a write-through cache to put in front of the datastore. If the model you are using is an NDB model, using a cache will be redundant since the model uses an instance cache and memcache for you. + user: users.User object, optional. Can be used to grab user ID as a + key_name if no key name is specified. 
""" + if key_name is None: + if user is None: + raise ValueError('StorageByKeyName called with no key name or user.') + key_name = user.user_id() + self._model = model self._key_name = key_name self._property_name = property_name @@ -426,28 +449,30 @@ def _delete_entity(self): entity_key = db.Key.from_path(self._model.kind(), self._key_name) db.delete(entity_key) + @db.non_transactional(allow_existing=True) def locked_get(self): """Retrieve Credential from datastore. Returns: oauth2client.Credentials """ + credentials = None if self._cache: json = self._cache.get(self._key_name) if json: - return Credentials.new_from_json(json) - - credentials = None - entity = self._get_entity() - if entity is not None: - credentials = getattr(entity, self._property_name) - if credentials and hasattr(credentials, 'set_store'): - credentials.set_store(self) + credentials = Credentials.new_from_json(json) + if credentials is None: + entity = self._get_entity() + if entity is not None: + credentials = getattr(entity, self._property_name) if self._cache: self._cache.set(self._key_name, credentials.to_json()) + if credentials and hasattr(credentials, 'set_store'): + credentials.set_store(self) return credentials + @db.non_transactional(allow_existing=True) def locked_put(self, credentials): """Write a Credentials to the datastore. @@ -460,6 +485,7 @@ def locked_put(self, credentials): if self._cache: self._cache.set(self._key_name, credentials.to_json()) + @db.non_transactional(allow_existing=True) def locked_delete(self): """Delete Credential from datastore.""" @@ -545,16 +571,14 @@ class OAuth2Decorator(object): Instantiate and then use with oauth_required or oauth_aware as decorators on webapp.RequestHandler methods. 
- Example: + :: decorator = OAuth2Decorator( client_id='837...ent.com', client_secret='Qh...wwI', scope='https://www.googleapis.com/auth/plus') - class MainHandler(webapp.RequestHandler): - @decorator.oauth_required def get(self): http = decorator.http() @@ -563,6 +587,37 @@ def get(self): """ + def set_credentials(self, credentials): + self._tls.credentials = credentials + + def get_credentials(self): + """A thread local Credentials object. + + Returns: + A client.Credentials object, or None if credentials hasn't been set in + this thread yet, which may happen when calling has_credentials inside + oauth_aware. + """ + return getattr(self._tls, 'credentials', None) + + credentials = property(get_credentials, set_credentials) + + def set_flow(self, flow): + self._tls.flow = flow + + def get_flow(self): + """A thread local Flow object. + + Returns: + A credentials.Flow object, or None if the flow hasn't been set in this + thread yet, which happens in _create_flow() since Flows are created + lazily. + """ + return getattr(self._tls, 'flow', None) + + flow = property(get_flow, set_flow) + + @util.positional(4) def __init__(self, client_id, client_secret, scope, auth_uri=GOOGLE_AUTH_URI, @@ -572,6 +627,9 @@ def __init__(self, client_id, client_secret, scope, message=None, callback_path='/oauth2callback', token_response_param=None, + _storage_class=StorageByKeyName, + _credentials_class=CredentialsModel, + _credentials_property_name='credentials', **kwargs): """Constructor for OAuth2Decorator @@ -598,9 +656,21 @@ def __init__(self, client_id, client_secret, scope, to the access token request will be encoded and included in this query parameter in the callback URI. This is useful with providers (e.g. wordpress.com) that include extra fields that the client may want. - **kwargs: dict, Keyword arguments are be passed along as kwargs to the - OAuth2WebServerFlow constructor. + _storage_class: "Protected" keyword argument not typically provided to + this constructor. 
A storage class to aid in storing a Credentials object + for a user in the datastore. Defaults to StorageByKeyName. + _credentials_class: "Protected" keyword argument not typically provided to + this constructor. A db or ndb Model class to hold credentials. Defaults + to CredentialsModel. + _credentials_property_name: "Protected" keyword argument not typically + provided to this constructor. A string indicating the name of the field + on the _credentials_class where a Credentials object will be stored. + Defaults to 'credentials'. + **kwargs: dict, Keyword arguments are passed along as kwargs to + the OAuth2WebServerFlow constructor. + """ + self._tls = threading.local() self.flow = None self.credentials = None self._client_id = client_id @@ -615,6 +685,9 @@ def __init__(self, client_id, client_secret, scope, self._in_error = False self._callback_path = callback_path self._token_response_param = token_response_param + self._storage_class = _storage_class + self._credentials_class = _credentials_class + self._credentials_property_name = _credentials_property_name def _display_error_message(self, request_handler): request_handler.response.out.write('') @@ -648,15 +721,19 @@ def check_oauth(request_handler, *args, **kwargs): # Store the request URI in 'state' so we can use it later self.flow.params['state'] = _build_state_value(request_handler, user) - self.credentials = StorageByKeyName( - CredentialsModel, user.user_id(), 'credentials').get() + self.credentials = self._storage_class( + self._credentials_class, None, + self._credentials_property_name, user=user).get() if not self.has_credentials(): return request_handler.redirect(self.authorize_url()) try: - return method(request_handler, *args, **kwargs) + resp = method(request_handler, *args, **kwargs) except AccessTokenRefreshError: return request_handler.redirect(self.authorize_url()) + finally: + self.credentials = None + return resp return check_oauth @@ -710,11 +787,17 @@ def setup_oauth(request_handler, 
*args, **kwargs): self._create_flow(request_handler) self.flow.params['state'] = _build_state_value(request_handler, user) - self.credentials = StorageByKeyName( - CredentialsModel, user.user_id(), 'credentials').get() - return method(request_handler, *args, **kwargs) + self.credentials = self._storage_class( + self._credentials_class, None, + self._credentials_property_name, user=user).get() + try: + resp = method(request_handler, *args, **kwargs) + finally: + self.credentials = None + return resp return setup_oauth + def has_credentials(self): """True if for the logged in user there are valid access Credentials. @@ -732,14 +815,18 @@ def authorize_url(self): url = self.flow.step1_get_authorize_url() return str(url) - def http(self): + def http(self, *args, **kwargs): """Returns an authorized http instance. Must only be called from within an @oauth_required decorated method, or from within an @oauth_aware decorated method where has_credentials() returns True. + + Args: + *args: Positional arguments passed to httplib2.Http constructor. + **kwargs: Positional arguments passed to httplib2.Http constructor. """ - return self.credentials.authorize(httplib2.Http()) + return self.credentials.authorize(httplib2.Http(*args, **kwargs)) @property def callback_path(self): @@ -758,7 +845,8 @@ def callback_path(self): def callback_handler(self): """RequestHandler for the OAuth 2.0 redirect callback. 
- Usage: + Usage:: + app = webapp.WSGIApplication([ ('/index', MyIndexHandler), ..., @@ -785,13 +873,14 @@ def get(self): user = users.get_current_user() decorator._create_flow(self) credentials = decorator.flow.step2_exchange(self.request.params) - StorageByKeyName( - CredentialsModel, user.user_id(), 'credentials').put(credentials) + decorator._storage_class( + decorator._credentials_class, None, + decorator._credentials_property_name, user=user).put(credentials) redirect_uri = _parse_state_value(str(self.request.get('state')), user) if decorator._token_response_param and credentials.token_response: - resp_json = simplejson.dumps(credentials.token_response) + resp_json = json.dumps(credentials.token_response) redirect_uri = util._add_query_parameter( redirect_uri, decorator._token_response_param, resp_json) @@ -820,24 +909,23 @@ class OAuth2DecoratorFromClientSecrets(OAuth2Decorator): Uses a clientsecrets file as the source for all the information when constructing an OAuth2Decorator. - Example: + :: decorator = OAuth2DecoratorFromClientSecrets( os.path.join(os.path.dirname(__file__), 'client_secrets.json') scope='https://www.googleapis.com/auth/plus') - class MainHandler(webapp.RequestHandler): - @decorator.oauth_required def get(self): http = decorator.http() # http is authorized with the user's Credentials and can be used # in API calls + """ @util.positional(3) - def __init__(self, filename, scope, message=None, cache=None): + def __init__(self, filename, scope, message=None, cache=None, **kwargs): """Constructor Args: @@ -850,17 +938,20 @@ def __init__(self, filename, scope, message=None, cache=None): decorator. cache: An optional cache service client that implements get() and set() methods. See clientsecrets.loadfile() for details. + **kwargs: dict, Keyword arguments are passed along as kwargs to + the OAuth2WebServerFlow constructor. 
""" client_type, client_info = clientsecrets.loadfile(filename, cache=cache) if client_type not in [ clientsecrets.TYPE_WEB, clientsecrets.TYPE_INSTALLED]: raise InvalidClientSecretsError( - 'OAuth2Decorator doesn\'t support this OAuth 2.0 flow.') - constructor_kwargs = { - 'auth_uri': client_info['auth_uri'], - 'token_uri': client_info['token_uri'], - 'message': message, - } + "OAuth2Decorator doesn't support this OAuth 2.0 flow.") + constructor_kwargs = dict(kwargs) + constructor_kwargs.update({ + 'auth_uri': client_info['auth_uri'], + 'token_uri': client_info['token_uri'], + 'message': message, + }) revoke_uri = client_info.get('revoke_uri') if revoke_uri is not None: constructor_kwargs['revoke_uri'] = revoke_uri diff --git a/oauth2client/client.py b/oauth2client/client.py index 53f53ba..f3ec3a3 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -1,4 +1,4 @@ -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -20,22 +20,25 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' import base64 -import clientsecrets +import collections import copy import datetime -import httplib2 +import json import logging import os +import socket import sys import time -import urllib -import urlparse +import six +from six.moves import urllib +import httplib2 +from oauth2client import clientsecrets from oauth2client import GOOGLE_AUTH_URI +from oauth2client import GOOGLE_DEVICE_URI from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI from oauth2client import util -from oauth2client.anyjson import simplejson HAS_OPENSSL = False HAS_CRYPTO = False @@ -47,18 +50,16 @@ except ImportError: pass -try: - from urlparse import parse_qsl -except ImportError: - from cgi import parse_qsl - logger = logging.getLogger(__name__) # Expiry is stored in RFC3339 UTC format EXPIRY_FORMAT = '%Y-%m-%dT%H:%M:%SZ' # Which certs to use to validate id_tokens received. -ID_TOKEN_VERIFICATON_CERTS = 'https://www.googleapis.com/oauth2/v1/certs' +ID_TOKEN_VERIFICATION_CERTS = 'https://www.googleapis.com/oauth2/v1/certs' +# This symbol previously had a typo in the name; we keep the old name +# around for now, but will remove it in the future. +ID_TOKEN_VERIFICATON_CERTS = ID_TOKEN_VERIFICATION_CERTS # Constant to use for the out of band OAuth 2.0 flow. OOB_CALLBACK_URN = 'urn:ietf:wg:oauth:2.0:oob' @@ -66,6 +67,39 @@ # Google Data client libraries may need to set this to [401, 403]. REFRESH_STATUS_CODES = [401] +# The value representing user credentials. +AUTHORIZED_USER = 'authorized_user' + +# The value representing service account credentials. +SERVICE_ACCOUNT = 'service_account' + +# The environment variable pointing the file with local +# Application Default Credentials. +GOOGLE_APPLICATION_CREDENTIALS = 'GOOGLE_APPLICATION_CREDENTIALS' + +# The error message we show users when we can't find the Application +# Default Credentials. 
+ADC_HELP_MSG = ( + 'The Application Default Credentials are not available. They are available ' + 'if running in Google Compute Engine. Otherwise, the environment variable ' + + GOOGLE_APPLICATION_CREDENTIALS + + ' must be defined pointing to a file defining the credentials. See ' + 'https://developers.google.com/accounts/docs/application-default-credentials' # pylint:disable=line-too-long + ' for more information.') + +# The access token along with the seconds in which it expires. +AccessTokenInfo = collections.namedtuple( + 'AccessTokenInfo', ['access_token', 'expires_in']) + +DEFAULT_ENV_NAME = 'UNKNOWN' + +# If set to True _get_environment avoid GCE check (_detect_gce_environment) +NO_GCE_CHECK = os.environ.setdefault('NO_GCE_CHECK', 'False') + +class SETTINGS(object): + """Settings namespace for globally defined values.""" + env_name = None + class Error(Exception): """Base error for this module.""" @@ -92,13 +126,25 @@ class AccessTokenCredentialsError(Error): class VerifyJwtTokenError(Error): - """Could on retrieve certificates for validation.""" + """Could not retrieve certificates for validation.""" class NonAsciiHeaderError(Error): """Header names and values must be ASCII strings.""" +class ApplicationDefaultCredentialsError(Error): + """Error retrieving the Application Default Credentials.""" + + +class OAuth2DeviceCodeError(Error): + """Error trying to retrieve a device code.""" + + +class CryptoUnavailableError(Error, NotImplementedError): + """Raised when a crypto library is required, but none is available.""" + + def _abstract(): raise NotImplementedError('You need to override this function') @@ -126,11 +172,12 @@ class Credentials(object): an HTTP transport. Subclasses must also specify a classmethod named 'from_json' that takes a JSON - string as input and returns an instaniated Credentials object. + string as input and returns an instantiated Credentials object. 
""" NON_SERIALIZED_MEMBERS = ['store'] + def authorize(self, http): """Take an httplib2.Http instance (or equivalent) and authorizes it. @@ -144,6 +191,7 @@ def authorize(self, http): """ _abstract() + def refresh(self, http): """Forces a refresh of the access_token. @@ -153,6 +201,7 @@ def refresh(self, http): """ _abstract() + def revoke(self, http): """Revokes a refresh_token and makes the credentials void. @@ -162,6 +211,7 @@ def revoke(self, http): """ _abstract() + def apply(self, headers): """Add the authorization to the headers. @@ -185,12 +235,16 @@ def _to_json(self, strip): for member in strip: if member in d: del d[member] - if 'token_expiry' in d and isinstance(d['token_expiry'], datetime.datetime): + if (d.get('token_expiry') and + isinstance(d['token_expiry'], datetime.datetime)): d['token_expiry'] = d['token_expiry'].strftime(EXPIRY_FORMAT) # Add in information we will need later to reconsistitue this instance. d['_class'] = t.__name__ d['_module'] = t.__module__ - return simplejson.dumps(d) + for key, val in d.items(): + if isinstance(val, bytes): + d[key] = val.decode('utf-8') + return json.dumps(d) def to_json(self): """Creating a JSON representation of an instance of Credentials. @@ -213,14 +267,16 @@ def new_from_json(cls, s): An instance of the subclass of Credentials that was serialized with to_json(). """ - data = simplejson.loads(s) + if six.PY3 and isinstance(s, bytes): + s = s.decode('utf-8') + data = json.loads(s) # Find and call the right classmethod from_json() to restore the object. 
module = data['_module'] try: m = __import__(module) except ImportError: # In case there's an object from the old package structure, update it - module = module.replace('.apiclient', '') + module = module.replace('.googleapiclient', '') m = __import__(module) m = __import__(module, fromlist=module.split('.')[:-1]) @@ -229,13 +285,13 @@ def new_from_json(cls, s): return from_json(s) @classmethod - def from_json(cls, s): + def from_json(cls, unused_data): """Instantiate a Credentials object from a JSON description of it. The JSON should have been produced by calling .to_json() on the object. Args: - data: dict, A deserialized JSON object. + unused_data: dict, A deserialized JSON object. Returns: An instance of a Credentials subclass. @@ -357,8 +413,10 @@ def clean_headers(headers): """ clean = {} try: - for k, v in headers.iteritems(): - clean[str(k)] = str(v) + for k, v in six.iteritems(headers): + clean_k = k if isinstance(k, bytes) else str(k).encode('ascii') + clean_v = v if isinstance(v, bytes) else str(v).encode('ascii') + clean[clean_k] = clean_v except UnicodeEncodeError: raise NonAsciiHeaderError(k + ': ' + v) return clean @@ -374,11 +432,11 @@ def _update_query_params(uri, params): Returns: The same URI but with the new query parameters added. """ - parts = list(urlparse.urlparse(uri)) - query_params = dict(parse_qsl(parts[4])) # 4 is the index of the query part + parts = urllib.parse.urlparse(uri) + query_params = dict(urllib.parse.parse_qsl(parts.query)) query_params.update(params) - parts[4] = urllib.urlencode(query_params) - return urlparse.urlunparse(parts) + new_parts = parts._replace(query=urllib.parse.urlencode(query_params)) + return urllib.parse.urlunparse(new_parts) class OAuth2Credentials(Credentials): @@ -446,22 +504,23 @@ def authorize(self, http): it. Args: - http: An instance of httplib2.Http - or something that acts like it. + http: An instance of ``httplib2.Http`` or something that acts + like it. 
Returns: A modified instance of http that was passed in. - Example: + Example:: h = httplib2.Http() h = credentials.authorize(h) - You can't create a new OAuth subclass of httplib2.Authenication + You can't create a new OAuth subclass of httplib2.Authentication because it never gets passed the absolute URI, which is needed for signing. So instead we have to overload 'request' with a closure that adds in the Authorization header and then calls the original version of 'request()'. + """ request_orig = http.request @@ -474,10 +533,12 @@ def new_request(uri, method='GET', body=None, headers=None, logger.info('Attempting refresh to obtain initial access_token') self._refresh(request_orig) - # Modify the request headers to add the appropriate + # Clone and modify the request headers to add the appropriate # Authorization header. if headers is None: headers = {} + else: + headers = dict(headers) self.apply(headers) if self.user_agent is not None: @@ -490,7 +551,7 @@ def new_request(uri, method='GET', body=None, headers=None, redirections, connection_type) if resp.status in REFRESH_STATUS_CODES: - logger.info('Refreshing due to a %s' % str(resp.status)) + logger.info('Refreshing due to a %s', resp.status) self._refresh(request_orig) self.apply(headers) return request_orig(uri, method, body, clean_headers(headers), @@ -546,13 +607,15 @@ def from_json(cls, s): Returns: An instance of a Credentials subclass. 
""" - data = simplejson.loads(s) - if 'token_expiry' in data and not isinstance(data['token_expiry'], - datetime.datetime): + if six.PY3 and isinstance(s, bytes): + s = s.decode('utf-8') + data = json.loads(s) + if (data.get('token_expiry') and + not isinstance(data['token_expiry'], datetime.datetime)): try: data['token_expiry'] = datetime.datetime.strptime( data['token_expiry'], EXPIRY_FORMAT) - except: + except ValueError: data['token_expiry'] = None retval = cls( data['access_token'], @@ -587,11 +650,24 @@ def access_token_expired(self): return True return False + def get_access_token(self, http=None): + """Return the access token and its expiration information. + + If the token does not exist, get one. + If the token expired, refresh it. + """ + if not self.access_token or self.access_token_expired: + if not http: + http = httplib2.Http() + self.refresh(http) + return AccessTokenInfo(access_token=self.access_token, + expires_in=self._expires_in()) + def set_store(self, store): """Set the Storage for the credential. Args: - store: Storage, an implementation of Stroage object. + store: Storage, an implementation of Storage object. This is needed to store the latest access_token if it has expired and been refreshed. This implementation uses locking to check for updates before updating the @@ -599,6 +675,25 @@ def set_store(self, store): """ self.store = store + def _expires_in(self): + """Return the number of seconds until this token expires. + + If token_expiry is in the past, this method will return 0, meaning the + token has already expired. + If token_expiry is None, this method will return None. Note that returning + 0 in such a case would not be fair: the token may still be valid; + we just don't know anything about it. 
+ """ + if self.token_expiry: + now = datetime.datetime.utcnow() + if self.token_expiry > now: + time_delta = self.token_expiry - now + # TODO(orestica): return time_delta.total_seconds() + # once dropping support for Python 2.6 + return time_delta.days * 86400 + time_delta.seconds + else: + return 0 + def _updateFromCredential(self, other): """Update this Credential from another instance.""" self.__dict__.update(other.__getstate__()) @@ -616,7 +711,7 @@ def __setstate__(self, state): def _generate_refresh_request_body(self): """Generate the body that will be used in the refresh request.""" - body = urllib.urlencode({ + body = urllib.parse.urlencode({ 'grant_type': 'refresh_token', 'client_id': self.client_id, 'client_secret': self.client_secret, @@ -680,9 +775,10 @@ def _do_refresh_request(self, http_request): logger.info('Refreshing access_token') resp, content = http_request( self.token_uri, method='POST', body=body, headers=headers) + if six.PY3 and isinstance(content, bytes): + content = content.decode('utf-8') if resp.status == 200: - # TODO(jcgregorio) Raise an error if loads fails? - d = simplejson.loads(content) + d = json.loads(content) self.token_response = d self.access_token = d['access_token'] self.refresh_token = d.get('refresh_token', self.refresh_token) @@ -691,35 +787,40 @@ def _do_refresh_request(self, http_request): seconds=int(d['expires_in'])) + datetime.datetime.utcnow() else: self.token_expiry = None + # On temporary refresh errors, the user does not actually have to + # re-authorize, so we unflag here. + self.invalid = False if self.store: self.store.locked_put(self) else: # An {'error':...} response body means the token is expired or revoked, # so we flag the credentials as such. - logger.info('Failed to retrieve access token: %s' % content) + logger.info('Failed to retrieve access token: %s', content) error_msg = 'Invalid response %s.' 
% resp['status'] try: - d = simplejson.loads(content) + d = json.loads(content) if 'error' in d: error_msg = d['error'] + if 'error_description' in d: + error_msg += ': ' + d['error_description'] self.invalid = True if self.store: self.store.locked_put(self) - except StandardError: + except (TypeError, ValueError): pass raise AccessTokenRefreshError(error_msg) def _revoke(self, http_request): - """Revokes the refresh_token and deletes the store if available. + """Revokes this credential and deletes the stored copy (if it exists). Args: http_request: callable, a callable that matches the method signature of httplib2.Http.request, used to make the revoke request. """ - self._do_revoke(http_request, self.refresh_token) + self._do_revoke(http_request, self.refresh_token or self.access_token) def _do_revoke(self, http_request, token): - """Revokes the credentials and deletes the store if available. + """Revokes this credential and deletes the stored copy (if it exists). Args: http_request: callable, a callable that matches the method signature of @@ -739,10 +840,10 @@ def _do_revoke(self, http_request, token): else: error_msg = 'Invalid response %s.' % resp.status try: - d = simplejson.loads(content) + d = json.loads(content) if 'error' in d: error_msg = d['error'] - except StandardError: + except (TypeError, ValueError): pass raise TokenRevokeError(error_msg) @@ -764,7 +865,8 @@ class AccessTokenCredentials(OAuth2Credentials): AccessTokenCredentials objects may be safely pickled and unpickled. 
- Usage: + Usage:: + credentials = AccessTokenCredentials('', 'my-user-agent/1.0') http = httplib2.Http() @@ -800,10 +902,12 @@ def __init__(self, access_token, user_agent, revoke_uri=None): @classmethod def from_json(cls, s): - data = simplejson.loads(s) + if six.PY3 and isinstance(s, bytes): + s = s.decode('utf-8') + data = json.loads(s) retval = AccessTokenCredentials( - data['access_token'], - data['user_agent']) + data['access_token'], + data['user_agent']) return retval def _refresh(self, http_request): @@ -820,7 +924,421 @@ def _revoke(self, http_request): self._do_revoke(http_request, self.access_token) -class AssertionCredentials(OAuth2Credentials): +def _detect_gce_environment(urlopen=None): + """Determine if the current environment is Compute Engine. + + Args: + urlopen: Optional argument. Function used to open a connection to a URL. + + Returns: + Boolean indicating whether or not the current environment is Google + Compute Engine. + """ + urlopen = urlopen or urllib.request.urlopen + # Note: the explicit `timeout` below is a workaround. The underlying + # issue is that resolving an unknown host on some networks will take + # 20-30 seconds; making this timeout short fixes the issue, but + # could lead to false negatives in the event that we are on GCE, but + # the metadata resolution was particularly slow. The latter case is + # "unlikely". + try: + response = urlopen('http://169.254.169.254/', timeout=1) + return response.info().get('Metadata-Flavor', '') == 'Google' + except socket.timeout: + logger.info('Timeout attempting to reach GCE metadata service.') + return False + except urllib.error.URLError as e: + if isinstance(getattr(e, 'reason', None), socket.timeout): + logger.info('Timeout attempting to reach GCE metadata service.') + return False + + +def _get_environment(urlopen=None): + """Detect the environment the code is being run on. + + Args: + urlopen: Optional argument. Function used to open a connection to a URL. 
+ + Returns: + The value of SETTINGS.env_name after being set. If already + set, simply returns the value. + """ + if SETTINGS.env_name is not None: + return SETTINGS.env_name + + # None is an unset value, not the default. + SETTINGS.env_name = DEFAULT_ENV_NAME + + server_software = os.environ.get('SERVER_SOFTWARE', '') + if server_software.startswith('Google App Engine/'): + SETTINGS.env_name = 'GAE_PRODUCTION' + elif server_software.startswith('Development/'): + SETTINGS.env_name = 'GAE_LOCAL' + elif NO_GCE_CHECK != 'True' and _detect_gce_environment(urlopen=urlopen): + SETTINGS.env_name = 'GCE_PRODUCTION' + + return SETTINGS.env_name + + +class GoogleCredentials(OAuth2Credentials): + """Application Default Credentials for use in calling Google APIs. + + The Application Default Credentials are being constructed as a function of + the environment where the code is being run. + More details can be found on this page: + https://developers.google.com/accounts/docs/application-default-credentials + + Here is an example of how to use the Application Default Credentials for a + service that requires authentication: + + from googleapiclient.discovery import build + from oauth2client.client import GoogleCredentials + + credentials = GoogleCredentials.get_application_default() + service = build('compute', 'v1', credentials=credentials) + + PROJECT = 'bamboo-machine-422' + ZONE = 'us-central1-a' + request = service.instances().list(project=PROJECT, zone=ZONE) + response = request.execute() + + print(response) + """ + + def __init__(self, access_token, client_id, client_secret, refresh_token, + token_expiry, token_uri, user_agent, + revoke_uri=GOOGLE_REVOKE_URI): + """Create an instance of GoogleCredentials. + + This constructor is not usually called by the user, instead + GoogleCredentials objects are instantiated by + GoogleCredentials.from_stream() or + GoogleCredentials.get_application_default(). + + Args: + access_token: string, access token. 
+ client_id: string, client identifier. + client_secret: string, client secret. + refresh_token: string, refresh token. + token_expiry: datetime, when the access_token expires. + token_uri: string, URI of token endpoint. + user_agent: string, The HTTP User-Agent to provide for this application. + revoke_uri: string, URI for revoke endpoint. + Defaults to GOOGLE_REVOKE_URI; a token can't be revoked if this is None. + """ + super(GoogleCredentials, self).__init__( + access_token, client_id, client_secret, refresh_token, token_expiry, + token_uri, user_agent, revoke_uri=revoke_uri) + + def create_scoped_required(self): + """Whether this Credentials object is scopeless. + + create_scoped(scopes) method needs to be called in order to create + a Credentials object for API calls. + """ + return False + + def create_scoped(self, scopes): + """Create a Credentials object for the given scopes. + + The Credentials type is preserved. + """ + return self + + @property + def serialization_data(self): + """Get the fields and their values identifying the current credentials.""" + return { + 'type': 'authorized_user', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self.refresh_token + } + + @staticmethod + def _implicit_credentials_from_gae(env_name=None): + """Attempts to get implicit credentials in Google App Engine env. + + If the current environment is not detected as App Engine, returns None, + indicating no Google App Engine credentials can be detected from the + current environment. + + Args: + env_name: String, indicating current environment. + + Returns: + None, if not in GAE, else an appengine.AppAssertionCredentials object. 
+ """ + env_name = env_name or _get_environment() + if env_name not in ('GAE_PRODUCTION', 'GAE_LOCAL'): + return None + + return _get_application_default_credential_GAE() + + @staticmethod + def _implicit_credentials_from_gce(env_name=None): + """Attempts to get implicit credentials in Google Compute Engine env. + + If the current environment is not detected as Compute Engine, returns None, + indicating no Google Compute Engine credentials can be detected from the + current environment. + + Args: + env_name: String, indicating current environment. + + Returns: + None, if not in GCE, else a gce.AppAssertionCredentials object. + """ + env_name = env_name or _get_environment() + if env_name != 'GCE_PRODUCTION': + return None + + return _get_application_default_credential_GCE() + + @staticmethod + def _implicit_credentials_from_files(env_name=None): + """Attempts to get implicit credentials from local credential files. + + First checks if the environment variable GOOGLE_APPLICATION_CREDENTIALS + is set with a filename and then falls back to a configuration file (the + "well known" file) associated with the 'gcloud' command line tool. + + Args: + env_name: Unused argument. + + Returns: + Credentials object associated with the GOOGLE_APPLICATION_CREDENTIALS + file or the "well known" file if either exist. If neither file is + define, returns None, indicating no credentials from a file can + detected from the current environment. 
+ """ + credentials_filename = _get_environment_variable_file() + if not credentials_filename: + credentials_filename = _get_well_known_file() + if os.path.isfile(credentials_filename): + extra_help = (' (produced automatically when running' + ' "gcloud auth login" command)') + else: + credentials_filename = None + else: + extra_help = (' (pointed to by ' + GOOGLE_APPLICATION_CREDENTIALS + + ' environment variable)') + + if not credentials_filename: + return + + try: + return _get_application_default_credential_from_file(credentials_filename) + except (ApplicationDefaultCredentialsError, ValueError) as error: + _raise_exception_for_reading_json(credentials_filename, extra_help, error) + + @classmethod + def _get_implicit_credentials(cls): + """Gets credentials implicitly from the environment. + + Checks environment in order of precedence: + - Google App Engine (production and testing) + - Environment variable GOOGLE_APPLICATION_CREDENTIALS pointing to + a file with stored credentials information. + - Stored "well known" file associated with `gcloud` command line tool. + - Google Compute Engine production environment. + + Exceptions: + ApplicationDefaultCredentialsError: raised when the credentials fail + to be retrieved. + """ + env_name = _get_environment() + + # Environ checks (in order). Assumes each checker takes `env_name` + # as a kwarg. + environ_checkers = [ + cls._implicit_credentials_from_gae, + cls._implicit_credentials_from_files, + cls._implicit_credentials_from_gce, + ] + + for checker in environ_checkers: + credentials = checker(env_name=env_name) + if credentials is not None: + return credentials + + # If no credentials, fail. + raise ApplicationDefaultCredentialsError(ADC_HELP_MSG) + + @staticmethod + def get_application_default(): + """Get the Application Default Credentials for the current environment. + + Exceptions: + ApplicationDefaultCredentialsError: raised when the credentials fail + to be retrieved. 
+ """ + return GoogleCredentials._get_implicit_credentials() + + @staticmethod + def from_stream(credential_filename): + """Create a Credentials object by reading the information from a given file. + + It returns an object of type GoogleCredentials. + + Args: + credential_filename: the path to the file from where the credentials + are to be read + + Exceptions: + ApplicationDefaultCredentialsError: raised when the credentials fail + to be retrieved. + """ + + if credential_filename and os.path.isfile(credential_filename): + try: + return _get_application_default_credential_from_file( + credential_filename) + except (ApplicationDefaultCredentialsError, ValueError) as error: + extra_help = ' (provided as parameter to the from_stream() method)' + _raise_exception_for_reading_json(credential_filename, + extra_help, + error) + else: + raise ApplicationDefaultCredentialsError( + 'The parameter passed to the from_stream() ' + 'method should point to a file.') + + +def save_to_well_known_file(credentials, well_known_file=None): + """Save the provided GoogleCredentials to the well known file. 
+ + Args: + credentials: + the credentials to be saved to the well known file; + it should be an instance of GoogleCredentials + well_known_file: + the name of the file where the credentials are to be saved; + this parameter is supposed to be used for testing only + """ + # TODO(orestica): move this method to tools.py + # once the argparse import gets fixed (it is not present in Python 2.6) + + if well_known_file is None: + well_known_file = _get_well_known_file() + + credentials_data = credentials.serialization_data + + with open(well_known_file, 'w') as f: + json.dump(credentials_data, f, sort_keys=True, indent=2, separators=(',', ': ')) + + +def _get_environment_variable_file(): + application_default_credential_filename = ( + os.environ.get(GOOGLE_APPLICATION_CREDENTIALS, + None)) + + if application_default_credential_filename: + if os.path.isfile(application_default_credential_filename): + return application_default_credential_filename + else: + raise ApplicationDefaultCredentialsError( + 'File ' + application_default_credential_filename + ' (pointed by ' + + GOOGLE_APPLICATION_CREDENTIALS + + ' environment variable) does not exist!') + + +def _get_well_known_file(): + """Get the well known file produced by command 'gcloud auth login'.""" + # TODO(orestica): Revisit this method once gcloud provides a better way + # of pinpointing the exact location of the file. + + WELL_KNOWN_CREDENTIALS_FILE = 'application_default_credentials.json' + CLOUDSDK_CONFIG_DIRECTORY = 'gcloud' + + if os.name == 'nt': + try: + default_config_path = os.path.join(os.environ['APPDATA'], + CLOUDSDK_CONFIG_DIRECTORY) + except KeyError: + # This should never happen unless someone is really messing with things. 
def _get_application_default_credential_from_file(filename):
  """Build the Application Default Credentials from a JSON key file.

  Args:
    filename: string, path to a JSON credentials file whose 'type' field is
      either 'authorized_user' or 'service_account'.

  Returns:
    GoogleCredentials for an authorized-user file, or service-account
    credentials for a service-account file.

  Raises:
    ApplicationDefaultCredentialsError: if the 'type' field is missing or
      unknown, or if fields required for that type are absent.
  """
  # Imported here (not at module level) to avoid a circular import.
  from oauth2client import service_account

  # read the credentials from the file
  with open(filename) as file_obj:
    client_credentials = json.load(file_obj)

  # Map each supported credential type to the fields it must carry.
  required_by_type = {
      AUTHORIZED_USER: set(['client_id', 'client_secret', 'refresh_token']),
      SERVICE_ACCOUNT: set(['client_id', 'client_email', 'private_key_id',
                            'private_key']),
  }

  credentials_type = client_credentials.get('type')
  if credentials_type not in required_by_type:
    raise ApplicationDefaultCredentialsError(
        "'type' field should be defined (and have one of the '" +
        AUTHORIZED_USER + "' or '" + SERVICE_ACCOUNT + "' values)")

  missing_fields = required_by_type[credentials_type].difference(
      client_credentials.keys())
  if missing_fields:
    _raise_exception_for_missing_fields(missing_fields)

  if credentials_type == AUTHORIZED_USER:
    return GoogleCredentials(
        access_token=None,
        client_id=client_credentials['client_id'],
        client_secret=client_credentials['client_secret'],
        refresh_token=client_credentials['refresh_token'],
        token_expiry=None,
        token_uri=GOOGLE_TOKEN_URI,
        user_agent='Python client library')
  # credentials_type == SERVICE_ACCOUNT
  return service_account._ServiceAccountCredentials(
      service_account_id=client_credentials['client_id'],
      service_account_email=client_credentials['client_email'],
      private_key_id=client_credentials['private_key_id'],
      private_key_pkcs8_text=client_credentials['private_key'],
      scopes=[])
def _raise_exception_for_missing_fields(missing_fields):
  """Raise ApplicationDefaultCredentialsError listing the absent fields."""
  raise ApplicationDefaultCredentialsError(
      'The following field(s) must be defined: ' + ', '.join(missing_fields))


def _raise_exception_for_reading_json(credential_file, extra_help, error):
  """Raise ApplicationDefaultCredentialsError for a JSON read failure.

  Args:
    credential_file: string, the file whose parsing failed.
    extra_help: string, context appended to the message (e.g. where the
      filename came from).
    error: the underlying exception, stringified into the message.
  """
  raise ApplicationDefaultCredentialsError(
      'An error was encountered while reading json file: ' +
      credential_file + extra_help + ': ' + str(error))


def _get_application_default_credential_GAE():
  """Return App Engine assertion credentials (no extra scopes)."""
  from oauth2client.appengine import AppAssertionCredentials
  return AppAssertionCredentials([])


def _get_application_default_credential_GCE():
  """Return Compute Engine assertion credentials (no extra scopes)."""
  from oauth2client.gce import AppAssertionCredentials
  return AppAssertionCredentials([])


def _RequireCryptoOrDie():
  """Ensure we have a crypto library, or throw CryptoUnavailableError.

  The oauth2client.crypt module requires either PyCrypto or PyOpenSSL
  to be available in order to function, but these are optional
  dependencies, so callers that must sign or verify JWTs call this first.
  """
  if not HAS_CRYPTO:
    raise CryptoUnavailableError('No crypto library available')
class SignedJwtAssertionCredentials(AssertionCredentials):
  """Credentials object used for OAuth 2.0 Signed JWT assertion grants.

  This credential does not require a flow to instantiate because it
  represents a two legged flow, and therefore has all of the required
  information to generate and refresh its own access tokens.

  SignedJwtAssertionCredentials requires either PyOpenSSL, or PyCrypto
  2.6 or later. For App Engine you may also consider using
  AppAssertionCredentials.
  """

  MAX_TOKEN_LIFETIME_SECS = 3600  # 1 hour in seconds

  @util.positional(4)
  def __init__(self,
               service_account_name,
               private_key,
               scope,
               private_key_password='notasecret',
               user_agent=None,
               token_uri=GOOGLE_TOKEN_URI,
               revoke_uri=GOOGLE_REVOKE_URI,
               **kwargs):
    """Constructor for SignedJwtAssertionCredentials.

    Args:
      service_account_name: string, id for account, usually an email address.
      private_key: string, private key in PKCS12 or PEM format.
      scope: string or iterable of strings, scope(s) of the credentials being
        requested.
      private_key_password: string, password for private_key, unused if
        private_key is in PEM format.
      user_agent: string, HTTP User-Agent to provide for this application.
      token_uri: string, URI for token endpoint. For convenience
        defaults to Google's endpoints but any OAuth 2.0 provider can be used.
      revoke_uri: string, URI for revoke endpoint.
      kwargs: kwargs, Additional parameters to add to the JWT token, for
        example sub=joe@xample.org.

    Raises:
      CryptoUnavailableError if no crypto library is available.
    """
    # Fail fast: signing a JWT is impossible without PyOpenSSL/PyCrypto.
    _RequireCryptoOrDie()
    super(SignedJwtAssertionCredentials, self).__init__(
        None,
        user_agent=user_agent,
        token_uri=token_uri,
        revoke_uri=revoke_uri,
    )

    self.scope = util.scopes_to_string(scope)

    # Keep base64 encoded so it can be stored in JSON; normalize to bytes
    # so serialization behaves the same on Python 2 and 3.
    self.private_key = base64.b64encode(private_key)
    if isinstance(self.private_key, six.text_type):
      self.private_key = self.private_key.encode('utf-8')

    self.private_key_password = private_key_password
    self.service_account_name = service_account_name
    self.kwargs = kwargs

  @classmethod
  def from_json(cls, s):
    """Rebuild an instance from the JSON produced by to_json()."""
    data = json.loads(s)
    credentials = SignedJwtAssertionCredentials(
        data['service_account_name'],
        base64.b64decode(data['private_key']),
        data['scope'],
        private_key_password=data['private_key_password'],
        user_agent=data['user_agent'],
        token_uri=data['token_uri'],
        **data['kwargs'])
    credentials.invalid = data['invalid']
    credentials.access_token = data['access_token']
    return credentials

  def _generate_assertion(self):
    """Generate the signed JWT assertion used in the refresh request."""
    issued_at = int(time.time())
    payload = {
        'aud': self.token_uri,
        'scope': self.scope,
        'iat': issued_at,
        'exp': issued_at + SignedJwtAssertionCredentials.MAX_TOKEN_LIFETIME_SECS,
        'iss': self.service_account_name,
    }
    payload.update(self.kwargs)
    logger.debug(str(payload))

    signer = crypt.Signer.from_string(
        base64.b64decode(self.private_key), self.private_key_password)
    return crypt.make_signed_jwt(signer, payload)
- """ - if http is None: - http = _cached_http + This function requires PyOpenSSL and because of that it does not work on + App Engine. - resp, content = http.request(cert_uri) + Args: + id_token: string, A Signed JWT. + audience: string, The audience 'aud' that the token should be for. + http: httplib2.Http, instance to use to make the HTTP request. Callers + should supply an instance that has caching enabled. + cert_uri: string, URI of the certificates in JSON format to + verify the JWT against. - if resp.status == 200: - certs = simplejson.loads(content) - return crypt.verify_signed_jwt_with_certs(id_token, certs, audience) - else: - raise VerifyJwtTokenError('Status code: %d' % resp.status) + Returns: + The deserialized JSON in the JWT. + + Raises: + oauth2client.crypt.AppIdentityError: if the JWT fails to verify. + CryptoUnavailableError: if no crypto library is available. + """ + _RequireCryptoOrDie() + if http is None: + http = _cached_http + + resp, content = http.request(cert_uri) + + if resp.status == 200: + certs = json.loads(content.decode('utf-8')) + return crypt.verify_signed_jwt_with_certs(id_token, certs, audience) + else: + raise VerifyJwtTokenError('Status code: %d' % resp.status) def _urlsafe_b64decode(b64string): # Guard against unicode strings, which base64 can't handle. - b64string = b64string.encode('ascii') - padded = b64string + '=' * (4 - len(b64string) % 4) + if isinstance(b64string, six.text_type): + b64string = b64string.encode('ascii') + padded = b64string + b'=' * (4 - len(b64string) % 4) return base64.urlsafe_b64decode(padded) @@ -1027,18 +1561,21 @@ def _extract_id_token(id_token): Does the extraction w/o checking the signature. Args: - id_token: string, OAuth 2.0 id_token. + id_token: string or bytestring, OAuth 2.0 id_token. Returns: object, The deserialized JSON payload. 
""" - segments = id_token.split('.') + if type(id_token) == bytes: + segments = id_token.split(b'.') + else: + segments = id_token.split(u'.') - if (len(segments) != 3): + if len(segments) != 3: raise VerifyJwtTokenError( - 'Wrong number of segments in token: %s' % id_token) + 'Wrong number of segments in token: %s' % id_token) - return simplejson.loads(_urlsafe_b64decode(segments[1])) + return json.loads(_urlsafe_b64decode(segments[1]).decode('utf-8')) def _parse_exchange_token_response(content): @@ -1056,11 +1593,12 @@ def _parse_exchange_token_response(content): """ resp = {} try: - resp = simplejson.loads(content) - except StandardError: + resp = json.loads(content.decode('utf-8')) + except Exception: # different JSON libs raise different exceptions, # so we just do a catch-all here - resp = dict(parse_qsl(content)) + content = content.decode('utf-8') + resp = dict(urllib.parse.parse_qsl(content)) # some providers respond with 'expires', others with 'expires_in' if resp and 'expires' in resp: @@ -1074,14 +1612,15 @@ def credentials_from_code(client_id, client_secret, scope, code, redirect_uri='postmessage', http=None, user_agent=None, token_uri=GOOGLE_TOKEN_URI, auth_uri=GOOGLE_AUTH_URI, - revoke_uri=GOOGLE_REVOKE_URI): + revoke_uri=GOOGLE_REVOKE_URI, + device_uri=GOOGLE_DEVICE_URI): """Exchanges an authorization code for an OAuth2Credentials object. Args: client_id: string, client identifier. client_secret: string, client secret. scope: string or iterable of strings, scope(s) to request. - code: string, An authroization code, most likely passed down from + code: string, An authorization code, most likely passed down from the client redirect_uri: string, this is generally set to 'postmessage' to match the redirect_uri that the client specified @@ -1092,6 +1631,8 @@ def credentials_from_code(client_id, client_secret, scope, code, defaults to Google's endpoints but any OAuth 2.0 provider can be used. revoke_uri: string, URI for revoke endpoint. 
For convenience defaults to Google's endpoints but any OAuth 2.0 provider can be used. + device_uri: string, URI for device authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider can be used. Returns: An OAuth2Credentials object. @@ -1103,7 +1644,7 @@ def credentials_from_code(client_id, client_secret, scope, code, flow = OAuth2WebServerFlow(client_id, client_secret, scope, redirect_uri=redirect_uri, user_agent=user_agent, auth_uri=auth_uri, token_uri=token_uri, - revoke_uri=revoke_uri) + revoke_uri=revoke_uri, device_uri=device_uri) credentials = flow.step2_exchange(code, http=http) return credentials @@ -1114,7 +1655,8 @@ def credentials_from_clientsecrets_and_code(filename, scope, code, message = None, redirect_uri='postmessage', http=None, - cache=None): + cache=None, + device_uri=None): """Returns OAuth2Credentials from a clientsecrets file and an auth code. Will create the right kind of Flow based on the contents of the clientsecrets @@ -1134,6 +1676,7 @@ def credentials_from_clientsecrets_and_code(filename, scope, code, http: httplib2.Http, optional http instance to use to do the fetch cache: An optional cache service client that implements get() and set() methods. See clientsecrets.loadfile() for details. + device_uri: string, OAuth 2.0 device authorization endpoint Returns: An OAuth2Credentials object. @@ -1146,11 +1689,49 @@ def credentials_from_clientsecrets_and_code(filename, scope, code, invalid. 
""" flow = flow_from_clientsecrets(filename, scope, message=message, cache=cache, - redirect_uri=redirect_uri) + redirect_uri=redirect_uri, + device_uri=device_uri) credentials = flow.step2_exchange(code, http=http) return credentials +class DeviceFlowInfo(collections.namedtuple('DeviceFlowInfo', ( + 'device_code', 'user_code', 'interval', 'verification_url', + 'user_code_expiry'))): + """Intermediate information the OAuth2 for devices flow.""" + + @classmethod + def FromResponse(cls, response): + """Create a DeviceFlowInfo from a server response. + + The response should be a dict containing entries as described here: + + http://tools.ietf.org/html/draft-ietf-oauth-v2-05#section-3.7.1 + """ + # device_code, user_code, and verification_url are required. + kwargs = { + 'device_code': response['device_code'], + 'user_code': response['user_code'], + } + # The response may list the verification address as either + # verification_url or verification_uri, so we check for both. + verification_url = response.get( + 'verification_url', response.get('verification_uri')) + if verification_url is None: + raise OAuth2DeviceCodeError( + 'No verification_url provided in server response') + kwargs['verification_url'] = verification_url + # expires_in and interval are optional. + kwargs.update({ + 'interval': response.get('interval'), + 'user_code_expiry': None, + }) + if 'expires_in' in response: + kwargs['user_code_expiry'] = datetime.datetime.now() + datetime.timedelta( + seconds=int(response['expires_in'])) + + return cls(**kwargs) + class OAuth2WebServerFlow(Flow): """Does the Web Server Flow for OAuth 2.0. @@ -1165,6 +1746,7 @@ def __init__(self, client_id, client_secret, scope, token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI, login_hint=None, + device_uri=GOOGLE_DEVICE_URI, **kwargs): """Constructor for OAuth2WebServerFlow. @@ -1190,6 +1772,8 @@ def __init__(self, client_id, client_secret, scope, login_hint: string, Either an email address or domain. 
Passing this hint will either pre-fill the email box on the sign-in form or select the proper multi-login session, thereby simplifying the login flow. + device_uri: string, URI for device authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider can be used. **kwargs: dict, The keyword arguments are all optional and required parameters for the OAuth calls. """ @@ -1202,6 +1786,7 @@ def __init__(self, client_id, client_secret, scope, self.auth_uri = auth_uri self.token_uri = token_uri self.revoke_uri = revoke_uri + self.device_uri = device_uri self.params = { 'access_type': 'offline', 'response_type': 'code', @@ -1222,8 +1807,9 @@ def step1_get_authorize_url(self, redirect_uri=None): A URI as a string to redirect the user to begin the authorization flow. """ if redirect_uri is not None: - logger.warning(('The redirect_uri parameter for' - 'OAuth2WebServerFlow.step1_get_authorize_url is deprecated. Please' + logger.warning(( + 'The redirect_uri parameter for ' + 'OAuth2WebServerFlow.step1_get_authorize_url is deprecated. 
Please ' 'move to passing the redirect_uri in via the constructor.')) self.redirect_uri = redirect_uri @@ -1240,42 +1826,102 @@ def step1_get_authorize_url(self, redirect_uri=None): query_params.update(self.params) return _update_query_params(self.auth_uri, query_params) + @util.positional(1) + def step1_get_device_and_user_codes(self, http=None): + """Returns a user code and the verification URL where to enter it + + Returns: + A user code as a string for the user to authorize the application + An URL as a string where the user has to enter the code + """ + if self.device_uri is None: + raise ValueError('The value of device_uri must not be None.') + + body = urllib.parse.urlencode({ + 'client_id': self.client_id, + 'scope': self.scope, + }) + headers = { + 'content-type': 'application/x-www-form-urlencoded', + } + + if self.user_agent is not None: + headers['user-agent'] = self.user_agent + + if http is None: + http = httplib2.Http() + + resp, content = http.request(self.device_uri, method='POST', body=body, + headers=headers) + if resp.status == 200: + try: + flow_info = json.loads(content) + except ValueError as e: + raise OAuth2DeviceCodeError( + 'Could not parse server response as JSON: "%s", error: "%s"' % ( + content, e)) + return DeviceFlowInfo.FromResponse(flow_info) + else: + error_msg = 'Invalid response %s.' % resp.status + try: + d = json.loads(content) + if 'error' in d: + error_msg += ' Error: %s' % d['error'] + except ValueError: + # Couldn't decode a JSON response, stick with the default message. + pass + raise OAuth2DeviceCodeError(error_msg) + @util.positional(2) - def step2_exchange(self, code, http=None): - """Exhanges a code for OAuth2Credentials. + def step2_exchange(self, code=None, http=None, device_flow_info=None): + """Exchanges a code for OAuth2Credentials. Args: - code: string or dict, either the code as a string, or a dictionary - of the query parameters to the redirect_uri, which contains - the code. 
- http: httplib2.Http, optional http instance to use to do the fetch + + code: string, a dict-like object, or None. For a non-device + flow, this is either the response code as a string, or a + dictionary of query parameters to the redirect_uri. For a + device flow, this should be None. + http: httplib2.Http, optional http instance to use when fetching + credentials. + device_flow_info: DeviceFlowInfo, return value from step1 in the + case of a device flow. Returns: An OAuth2Credentials object that can be used to authorize requests. Raises: - FlowExchangeError if a problem occured exchanging the code for a - refresh_token. - """ + FlowExchangeError: if a problem occurred exchanging the code for a + refresh_token. + ValueError: if code and device_flow_info are both provided or both + missing. - if not (isinstance(code, str) or isinstance(code, unicode)): + """ + if code is None and device_flow_info is None: + raise ValueError('No code or device_flow_info provided.') + if code is not None and device_flow_info is not None: + raise ValueError('Cannot provide both code and device_flow_info.') + + if code is None: + code = device_flow_info.device_code + elif not isinstance(code, six.string_types): if 'code' not in code: - if 'error' in code: - error_msg = code['error'] - else: - error_msg = 'No code was supplied in the query parameters.' 
- raise FlowExchangeError(error_msg) - else: - code = code['code'] + raise FlowExchangeError(code.get( + 'error', 'No code was supplied in the query parameters.')) + code = code['code'] - body = urllib.urlencode({ - 'grant_type': 'authorization_code', + post_data = { 'client_id': self.client_id, 'client_secret': self.client_secret, 'code': code, - 'redirect_uri': self.redirect_uri, 'scope': self.scope, - }) + } + if device_flow_info is not None: + post_data['grant_type'] = 'http://oauth.net/grant_type/device/1.0' + else: + post_data['grant_type'] = 'authorization_code' + post_data['redirect_uri'] = self.redirect_uri + body = urllib.parse.urlencode(post_data) headers = { 'content-type': 'application/x-www-form-urlencoded', } @@ -1292,26 +1938,31 @@ def step2_exchange(self, code, http=None): if resp.status == 200 and 'access_token' in d: access_token = d['access_token'] refresh_token = d.get('refresh_token', None) + if not refresh_token: + logger.info( + 'Received token response with no refresh_token. 
Consider ' + "reauthenticating with approval_prompt='force'.") token_expiry = None if 'expires_in' in d: token_expiry = datetime.datetime.utcnow() + datetime.timedelta( seconds=int(d['expires_in'])) + extracted_id_token = None if 'id_token' in d: - d['id_token'] = _extract_id_token(d['id_token']) + extracted_id_token = _extract_id_token(d['id_token']) logger.info('Successfully retrieved access token') return OAuth2Credentials(access_token, self.client_id, self.client_secret, refresh_token, token_expiry, self.token_uri, self.user_agent, revoke_uri=self.revoke_uri, - id_token=d.get('id_token', None), + id_token=extracted_id_token, token_response=d) else: - logger.info('Failed to retrieve access token: %s' % content) + logger.info('Failed to retrieve access token: %s', content) if 'error' in d: # you never know what those providers got to say - error_msg = unicode(d['error']) + error_msg = str(d['error']) + str(d.get('error_description', '')) else: error_msg = 'Invalid response: %s.' % str(resp.status) raise FlowExchangeError(error_msg) @@ -1319,7 +1970,8 @@ def step2_exchange(self, code, http=None): @util.positional(2) def flow_from_clientsecrets(filename, scope, redirect_uri=None, - message=None, cache=None, login_hint=None): + message=None, cache=None, login_hint=None, + device_uri=None): """Create a Flow from a clientsecrets file. Will create the right kind of Flow based on the contents of the clientsecrets @@ -1340,6 +1992,9 @@ def flow_from_clientsecrets(filename, scope, redirect_uri=None, login_hint: string, Either an email address or domain. Passing this hint will either pre-fill the email box on the sign-in form or select the proper multi-login session, thereby simplifying the login flow. + device_uri: string, URI for device authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider can be used. + Returns: A Flow object. 
@@ -1360,6 +2015,8 @@ def flow_from_clientsecrets(filename, scope, redirect_uri=None, revoke_uri = client_info.get('revoke_uri') if revoke_uri is not None: constructor_kwargs['revoke_uri'] = revoke_uri + if device_uri is not None: + constructor_kwargs['device_uri'] = device_uri return OAuth2WebServerFlow( client_info['client_id'], client_info['client_secret'], scope, **constructor_kwargs) diff --git a/oauth2client/clientsecrets.py b/oauth2client/clientsecrets.py index ac99aae..08a1702 100644 --- a/oauth2client/clientsecrets.py +++ b/oauth2client/clientsecrets.py @@ -1,4 +1,4 @@ -# Copyright (C) 2011 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,8 +20,9 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' +import json +import six -from anyjson import simplejson # Properties that make a client_secrets.json file valid. TYPE_WEB = 'web' @@ -68,11 +69,21 @@ class InvalidClientSecretsError(Error): def _validate_clientsecrets(obj): - if obj is None or len(obj) != 1: - raise InvalidClientSecretsError('Invalid file format.') - client_type = obj.keys()[0] - if client_type not in VALID_CLIENT.keys(): - raise InvalidClientSecretsError('Unknown client type: %s.' % client_type) + _INVALID_FILE_FORMAT_MSG = ( + 'Invalid file format. See ' + 'https://developers.google.com/api-client-library/' + 'python/guide/aaa_client_secrets') + + if obj is None: + raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG) + if len(obj) != 1: + raise InvalidClientSecretsError( + _INVALID_FILE_FORMAT_MSG + ' ' + 'Expected a JSON object with a single property for a "web" or ' + '"installed" application') + client_type = tuple(obj)[0] + if client_type not in VALID_CLIENT: + raise InvalidClientSecretsError('Unknown client type: %s.' 
def load(fp):
  """Parse client secrets from an open file object and validate them."""
  return _validate_clientsecrets(json.load(fp))


def loads(s):
  """Parse client secrets from a JSON string and validate them."""
  return _validate_clientsecrets(json.loads(s))


def _loadfile(filename):
  """Read, parse and validate a client secrets file by name.

  Raises:
    InvalidClientSecretsError: if the file cannot be opened or fails
      validation.
  """
  try:
    with open(filename, 'r') as fp:
      obj = json.load(fp)
  except IOError:
    raise InvalidClientSecretsError('File not found: "%s"' % filename)
  return _validate_clientsecrets(obj)
+"""Crypto-related routines for oauth2client.""" import base64 -import hashlib +import json import logging +import sys import time -from anyjson import simplejson +import six CLOCK_SKEW_SECS = 300 # 5 minutes in seconds @@ -38,7 +39,6 @@ class AppIdentityError(Exception): try: from OpenSSL import crypto - class OpenSSLVerifier(object): """Verifies the signature on a message.""" @@ -62,6 +62,8 @@ def verify(self, message, signature): key that this object was constructed with. """ try: + if isinstance(message, six.text_type): + message = message.encode('utf-8') crypto.verify(self._pubkey, signature, message, 'sha256') return True except: @@ -104,15 +106,17 @@ def sign(self, message): """Signs a message. Args: - message: string, Message to be signed. + message: bytes, Message to be signed. Returns: string, The signature of the message for the given key. """ + if isinstance(message, six.text_type): + message = message.encode('utf-8') return crypto.sign(self._key, message, 'sha256') @staticmethod - def from_string(key, password='notasecret'): + def from_string(key, password=b'notasecret'): """Construct a Signer instance from a string. Args: @@ -125,21 +129,45 @@ def from_string(key, password='notasecret'): Raises: OpenSSL.crypto.Error if the key can't be parsed. """ - if key.startswith('-----BEGIN '): - pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key) + parsed_pem_key = _parse_pem_key(key) + if parsed_pem_key: + pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key) else: + if isinstance(password, six.text_type): + password = password.encode('utf-8') pkey = crypto.load_pkcs12(key, password).get_privatekey() return OpenSSLSigner(pkey) + + def pkcs12_key_as_pem(private_key_text, private_key_password): + """Convert the contents of a PKCS12 key to PEM using OpenSSL. + + Args: + private_key_text: String. Private key. + private_key_password: String. Password for PKCS12. + + Returns: + String. PEM contents of ``private_key_text``. 
+ """ + decoded_body = base64.b64decode(private_key_text) + if isinstance(private_key_password, six.string_types): + private_key_password = private_key_password.encode('ascii') + + pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password) + return crypto.dump_privatekey(crypto.FILETYPE_PEM, + pkcs12.get_privatekey()) except ImportError: OpenSSLVerifier = None OpenSSLSigner = None + def pkcs12_key_as_pem(*args, **kwargs): + raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.') try: from Crypto.PublicKey import RSA from Crypto.Hash import SHA256 from Crypto.Signature import PKCS1_v1_5 + from Crypto.Util.asn1 import DerSequence class PyCryptoVerifier(object): @@ -181,14 +209,17 @@ def from_string(key_pem, is_x509_cert): Returns: Verifier instance. - - Raises: - NotImplementedError if is_x509_cert is true. """ if is_x509_cert: - raise NotImplementedError( - 'X509 certs are not supported by the PyCrypto library. ' - 'Try using PyOpenSSL if native code is an option.') + if isinstance(key_pem, six.text_type): + key_pem = key_pem.encode('ascii') + pemLines = key_pem.replace(b' ', b'').split() + certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1])) + certSeq = DerSequence() + certSeq.decode(certDer) + tbsSeq = DerSequence() + tbsSeq.decode(certSeq[0]) + pubkey = RSA.importKey(tbsSeq[6]) else: pubkey = RSA.importKey(key_pem) return PyCryptoVerifier(pubkey) @@ -214,6 +245,8 @@ def sign(self, message): Returns: string, The signature of the message for the given key. """ + if isinstance(message, six.text_type): + message = message.encode('utf-8') return PKCS1_v1_5.new(self._key).sign(SHA256.new(message)) @staticmethod @@ -230,11 +263,12 @@ def from_string(key, password='notasecret'): Raises: NotImplementedError if they key isn't in PEM format. 
""" - if key.startswith('-----BEGIN '): - pkey = RSA.importKey(key) + parsed_pem_key = _parse_pem_key(key) + if parsed_pem_key: + pkey = RSA.importKey(parsed_pem_key) else: raise NotImplementedError( - 'PKCS12 format is not supported by the PyCrpto library. ' + 'PKCS12 format is not supported by the PyCrypto library. ' 'Try converting to a "PEM" ' '(openssl pkcs12 -in xxxxx.p12 -nodes -nocerts > privatekey.pem) ' 'or using PyOpenSSL if native code is an option.') @@ -256,19 +290,39 @@ def from_string(key, password='notasecret'): 'PyOpenSSL, or PyCrypto 2.6 or later') +def _parse_pem_key(raw_key_input): + """Identify and extract PEM keys. + + Determines whether the given key is in the format of PEM key, and extracts + the relevant part of the key if it is. + + Args: + raw_key_input: The contents of a private key file (either PEM or PKCS12). + + Returns: + string, The actual key if the contents are from a PEM file, or else None. + """ + offset = raw_key_input.find(b'-----BEGIN ') + if offset != -1: + return raw_key_input[offset:] + + def _urlsafe_b64encode(raw_bytes): - return base64.urlsafe_b64encode(raw_bytes).rstrip('=') + if isinstance(raw_bytes, six.text_type): + raw_bytes = raw_bytes.encode('utf-8') + return base64.urlsafe_b64encode(raw_bytes).decode('ascii').rstrip('=') def _urlsafe_b64decode(b64string): # Guard against unicode strings, which base64 can't handle. 
- b64string = b64string.encode('ascii') - padded = b64string + '=' * (4 - len(b64string) % 4) + if isinstance(b64string, six.text_type): + b64string = b64string.encode('ascii') + padded = b64string + b'=' * (4 - len(b64string) % 4) return base64.urlsafe_b64decode(padded) def _json_encode(data): - return simplejson.dumps(data, separators = (',', ':')) + return json.dumps(data, separators=(',', ':')) def make_signed_jwt(signer, payload): @@ -286,8 +340,8 @@ def make_signed_jwt(signer, payload): header = {'typ': 'JWT', 'alg': 'RS256'} segments = [ - _urlsafe_b64encode(_json_encode(header)), - _urlsafe_b64encode(_json_encode(payload)), + _urlsafe_b64encode(_json_encode(header)), + _urlsafe_b64encode(_json_encode(payload)), ] signing_input = '.'.join(segments) @@ -318,9 +372,8 @@ def verify_signed_jwt_with_certs(jwt, certs, audience): """ segments = jwt.split('.') - if (len(segments) != 3): - raise AppIdentityError( - 'Wrong number of segments in token: %s' % jwt) + if len(segments) != 3: + raise AppIdentityError('Wrong number of segments in token: %s' % jwt) signed = '%s.%s' % (segments[0], segments[1]) signature = _urlsafe_b64decode(segments[2]) @@ -328,15 +381,15 @@ def verify_signed_jwt_with_certs(jwt, certs, audience): # Parse token. json_body = _urlsafe_b64decode(segments[1]) try: - parsed = simplejson.loads(json_body) + parsed = json.loads(json_body.decode('utf-8')) except: raise AppIdentityError('Can\'t parse token: %s' % json_body) # Check signature. verified = False - for (keyname, pem) in certs.items(): + for pem in certs.values(): verifier = Verifier.from_string(pem, True) - if (verifier.verify(signed, signature)): + if verifier.verify(signed, signature): verified = True break if not verified: @@ -349,21 +402,20 @@ def verify_signed_jwt_with_certs(jwt, certs, audience): earliest = iat - CLOCK_SKEW_SECS # Check expiration timestamp. 
- now = long(time.time()) + now = int(time.time()) exp = parsed.get('exp') if exp is None: raise AppIdentityError('No exp field in token: %s' % json_body) if exp >= now + MAX_TOKEN_LIFETIME_SECS: - raise AppIdentityError( - 'exp field too far in future: %s' % json_body) + raise AppIdentityError('exp field too far in future: %s' % json_body) latest = exp + CLOCK_SKEW_SECS if now < earliest: raise AppIdentityError('Token used too early, %d < %d: %s' % - (now, earliest, json_body)) + (now, earliest, json_body)) if now > latest: raise AppIdentityError('Token used too late, %d > %d: %s' % - (now, latest, json_body)) + (now, latest, json_body)) # Check audience. if audience is not None: @@ -372,6 +424,6 @@ def verify_signed_jwt_with_certs(jwt, certs, audience): raise AppIdentityError('No aud field in token: %s' % json_body) if aud != audience: raise AppIdentityError('Wrong recipient, %s != %s: %s' % - (aud, audience, json_body)) + (aud, audience, json_body)) return parsed diff --git a/oauth2client/devshell.py b/oauth2client/devshell.py new file mode 100644 index 0000000..a33de87 --- /dev/null +++ b/oauth2client/devshell.py @@ -0,0 +1,136 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""OAuth 2.0 utitilies for Google Developer Shell environment.""" + +import json +import os + +from oauth2client import client + + +DEVSHELL_ENV = 'DEVSHELL_CLIENT_PORT' + + +class Error(Exception): + """Errors for this module.""" + pass + + +class CommunicationError(Error): + """Errors for communication with the Developer Shell server.""" + + +class NoDevshellServer(Error): + """Error when no Developer Shell server can be contacted.""" + + +# The request for credential information to the Developer Shell client socket is +# always an empty PBLite-formatted JSON object, so just define it as a constant. +CREDENTIAL_INFO_REQUEST_JSON = '[]' + + +class CredentialInfoResponse(object): + """Credential information response from Developer Shell server. + + The credential information response from Developer Shell socket is a + PBLite-formatted JSON array with fields encoded by their index in the array: + * Index 0 - user email + * Index 1 - default project ID. None if the project context is not known. + * Index 2 - OAuth2 access token. None if there is no valid auth context. 
+ """ + + def __init__(self, json_string): + """Initialize the response data from JSON PBLite array.""" + pbl = json.loads(json_string) + if not isinstance(pbl, list): + raise ValueError('Not a list: ' + str(pbl)) + pbl_len = len(pbl) + self.user_email = pbl[0] if pbl_len > 0 else None + self.project_id = pbl[1] if pbl_len > 1 else None + self.access_token = pbl[2] if pbl_len > 2 else None + + +def _SendRecv(): + """Communicate with the Developer Shell server socket.""" + + port = int(os.getenv(DEVSHELL_ENV, 0)) + if port == 0: + raise NoDevshellServer() + + import socket + + sock = socket.socket() + sock.connect(('localhost', port)) + + data = CREDENTIAL_INFO_REQUEST_JSON + msg = '%s\n%s' % (len(data), data) + sock.sendall(msg.encode()) + + header = sock.recv(6).decode() + if '\n' not in header: + raise CommunicationError('saw no newline in the first 6 bytes') + len_str, json_str = header.split('\n', 1) + to_read = int(len_str) - len(json_str) + if to_read > 0: + json_str += sock.recv(to_read, socket.MSG_WAITALL).decode() + + return CredentialInfoResponse(json_str) + + +class DevshellCredentials(client.GoogleCredentials): + """Credentials object for Google Developer Shell environment. + + This object will allow a Google Developer Shell session to identify its user + to Google and other OAuth 2.0 servers that can verify assertions. It can be + used for the purpose of accessing data stored under the user account. + + This credential does not require a flow to instantiate because it represents + a two legged flow, and therefore has all of the required information to + generate and refresh its own access tokens. 
+ """ + + def __init__(self, user_agent=None): + super(DevshellCredentials, self).__init__( + None, # access_token, initialized below + None, # client_id + None, # client_secret + None, # refresh_token + None, # token_expiry + None, # token_uri + user_agent) + self._refresh(None) + + def _refresh(self, http_request): + self.devshell_response = _SendRecv() + self.access_token = self.devshell_response.access_token + + @property + def user_email(self): + return self.devshell_response.user_email + + @property + def project_id(self): + return self.devshell_response.project_id + + @classmethod + def from_json(cls, json_data): + raise NotImplementedError( + 'Cannot load Developer Shell credentials from JSON.') + + @property + def serialization_data(self): + raise NotImplementedError( + 'Cannot serialize Developer Shell credentials.') + diff --git a/oauth2client/django_orm.py b/oauth2client/django_orm.py index d54d20c..65c5d20 100644 --- a/oauth2client/django_orm.py +++ b/oauth2client/django_orm.py @@ -1,4 +1,4 @@ -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -116,14 +116,21 @@ def locked_get(self): credential.set_store(self) return credential - def locked_put(self, credentials): + def locked_put(self, credentials, overwrite=False): """Write a Credentials to the datastore. Args: credentials: Credentials, the credentials to store. + overwrite: Boolean, indicates whether you would like these credentials to + overwrite any existing stored credentials. 
""" args = {self.key_name: self.key_value} - entity = self.model_class(**args) + + if overwrite: + entity, unused_is_new = self.model_class.objects.get_or_create(**args) + else: + entity = self.model_class(**args) + setattr(entity, self.property_name, credentials) entity.save() diff --git a/oauth2client/file.py b/oauth2client/file.py index 1895f94..9d0ae7f 100644 --- a/oauth2client/file.py +++ b/oauth2client/file.py @@ -1,4 +1,4 @@ -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,12 +21,10 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' import os -import stat import threading -from anyjson import simplejson -from client import Storage as BaseStorage -from client import Credentials +from oauth2client.client import Credentials +from oauth2client.client import Storage as BaseStorage class CredentialsFileSymbolicLinkError(Exception): @@ -92,7 +90,7 @@ def _create_file_if_needed(self): simple version of "touch" to ensure the file has been created. """ if not os.path.exists(self._filename): - old_umask = os.umask(0177) + old_umask = os.umask(0o177) try: open(self._filename, 'a+b').close() finally: @@ -110,7 +108,7 @@ def locked_put(self, credentials): self._create_file_if_needed() self._validate_file() - f = open(self._filename, 'wb') + f = open(self._filename, 'w') f.write(credentials.to_json()) f.close() diff --git a/oauth2client/gce.py b/oauth2client/gce.py index c7fd7c1..fc3bd77 100644 --- a/oauth2client/gce.py +++ b/oauth2client/gce.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,12 +19,11 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' -import httplib2 +import json import logging -import uritemplate +from six.moves import urllib from oauth2client import util -from oauth2client.anyjson import simplejson from oauth2client.client import AccessTokenRefreshError from oauth2client.client import AssertionCredentials @@ -57,13 +56,14 @@ def __init__(self, scope, **kwargs): requested. """ self.scope = util.scopes_to_string(scope) + self.kwargs = kwargs # Assertion type is no longer used, but still in the parent class signature. super(AppAssertionCredentials, self).__init__(None) @classmethod - def from_json(cls, json): - data = simplejson.loads(json) + def from_json(cls, json_data): + data = json.loads(json_data) return AppAssertionCredentials(data['scope']) def _refresh(self, http_request): @@ -78,13 +78,28 @@ def _refresh(self, http_request): Raises: AccessTokenRefreshError: When the refresh fails. """ - uri = uritemplate.expand(META, {'scope': self.scope}) + query = '?scope=%s' % urllib.parse.quote(self.scope, '') + uri = META.replace('{?scope}', query) response, content = http_request(uri) if response.status == 200: try: - d = simplejson.loads(content) - except StandardError, e: + d = json.loads(content) + except Exception as e: raise AccessTokenRefreshError(str(e)) self.access_token = d['accessToken'] else: + if response.status == 404: + content += (' This can occur if a VM was created' + ' with no service account or scopes.') raise AccessTokenRefreshError(content) + + @property + def serialization_data(self): + raise NotImplementedError( + 'Cannot serialize credentials for GCE service accounts.') + + def create_scoped_required(self): + return not self.scope + + def create_scoped(self, scopes): + return AppAssertionCredentials(scopes, **self.kwargs) diff --git a/oauth2client/keyring_storage.py b/oauth2client/keyring_storage.py index efe2949..cda1d9a 100644 --- a/oauth2client/keyring_storage.py +++ 
b/oauth2client/keyring_storage.py @@ -1,4 +1,4 @@ -# Copyright (C) 2012 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,11 +19,12 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)' -import keyring import threading -from client import Storage as BaseStorage -from client import Credentials +import keyring + +from oauth2client.client import Credentials +from oauth2client.client import Storage as BaseStorage class Storage(BaseStorage): diff --git a/oauth2client/locked_file.py b/oauth2client/locked_file.py index 26f783e..af92398 100644 --- a/oauth2client/locked_file.py +++ b/oauth2client/locked_file.py @@ -1,21 +1,37 @@ -# Copyright 2011 Google Inc. All Rights Reserved. +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Locked file interface that should work on Unix and Windows pythons. This module first tries to use fcntl locking to ensure serialized access to a file, then falls back on a lock file if that is unavialable. 
-Usage: +Usage:: + f = LockedFile('filename', 'r+b', 'rb') f.open_and_lock() if f.is_locked(): - print 'Acquired filename with r+b mode' + print('Acquired filename with r+b mode') f.file_handle().write('locked data') else: - print 'Aquired filename with rb mode' + print('Acquired filename with rb mode') f.unlock_and_close() + """ +from __future__ import print_function + __author__ = 'cache@google.com (David T McWherter)' import errno @@ -58,6 +74,7 @@ def __init__(self, filename, mode, fallback_mode): self._mode = mode self._fallback_mode = fallback_mode self._fh = None + self._lock_fd = None def is_locked(self): """Was the file locked.""" @@ -110,7 +127,7 @@ def open_and_lock(self, timeout, delay): validate_file(self._filename) try: self._fh = open(self._filename, self._mode) - except IOError, e: + except IOError as e: # If we can't access with _mode, try _fallback_mode and don't lock. if e.errno == errno.EACCES: self._fh = open(self._filename, self._fallback_mode) @@ -125,12 +142,12 @@ def open_and_lock(self, timeout, delay): self._locked = True break - except OSError, e: + except OSError as e: if e.errno != errno.EEXIST: raise if (time.time() - start_time) >= timeout: - logger.warn('Could not acquire lock %s in %s seconds' % ( - lock_filename, timeout)) + logger.warn('Could not acquire lock %s in %s seconds', + lock_filename, timeout) # Close the file and open in fallback_mode. if self._fh: self._fh.close() @@ -180,9 +197,9 @@ def open_and_lock(self, timeout, delay): validate_file(self._filename) try: self._fh = open(self._filename, self._mode) - except IOError, e: + except IOError as e: # If we can't access with _mode, try _fallback_mode and don't lock. 
- if e.errno == errno.EACCES: + if e.errno in (errno.EPERM, errno.EACCES): self._fh = open(self._filename, self._fallback_mode) return @@ -192,16 +209,16 @@ def open_and_lock(self, timeout, delay): fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX) self._locked = True return - except IOError, e: + except IOError as e: # If not retrying, then just pass on the error. if timeout == 0: - raise e + raise if e.errno != errno.EACCES: - raise e + raise # We could not acquire the lock. Try again. if (time.time() - start_time) >= timeout: - logger.warn('Could not lock %s in %s seconds' % ( - self._filename, timeout)) + logger.warn('Could not lock %s in %s seconds', + self._filename, timeout) if self._fh: self._fh.close() self._fh = open(self._filename, self._fallback_mode) @@ -255,7 +272,7 @@ def open_and_lock(self, timeout, delay): validate_file(self._filename) try: self._fh = open(self._filename, self._mode) - except IOError, e: + except IOError as e: # If we can't access with _mode, try _fallback_mode and don't lock. if e.errno == errno.EACCES: self._fh = open(self._filename, self._fallback_mode) @@ -272,9 +289,9 @@ def open_and_lock(self, timeout, delay): pywintypes.OVERLAPPED()) self._locked = True return - except pywintypes.error, e: + except pywintypes.error as e: if timeout == 0: - raise e + raise # If the error is not that the file is already in use, raise. if e[0] != _Win32Opener.FILE_IN_USE_ERROR: @@ -296,7 +313,7 @@ def unlock_and_close(self): try: hfile = win32file._get_osfhandle(self._fh.fileno()) win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED()) - except pywintypes.error, e: + except pywintypes.error as e: if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR: raise self._locked = False diff --git a/oauth2client/multistore_file.py b/oauth2client/multistore_file.py index e1b39f7..f4ba4a7 100644 --- a/oauth2client/multistore_file.py +++ b/oauth2client/multistore_file.py @@ -1,4 +1,16 @@ -# Copyright 2011 Google Inc. All Rights Reserved. 
+# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Multi-credential file store with lock support. @@ -7,41 +19,43 @@ both in a single process and across processes. The credential themselves are keyed off of: + * client_id * user_agent * scope -The format of the stored data is like so: -{ - 'file_version': 1, - 'data': [ - { - 'key': { - 'clientId': '', - 'userAgent': '', - 'scope': '' - }, - 'credential': { - # JSON serialized Credentials. +The format of the stored data is like so:: + + { + 'file_version': 1, + 'data': [ + { + 'key': { + 'clientId': '', + 'userAgent': '', + 'scope': '' + }, + 'credential': { + # JSON serialized Credentials. 
+ } } - } - ] -} + ] + } + """ __author__ = 'jbeda@google.com (Joe Beda)' -import base64 import errno +import json import logging import os import threading -from anyjson import simplejson -from oauth2client.client import Storage as BaseStorage from oauth2client.client import Credentials +from oauth2client.client import Storage as BaseStorage from oauth2client import util -from locked_file import LockedFile +from oauth2client.locked_file import LockedFile logger = logging.getLogger(__name__) @@ -52,12 +66,10 @@ class Error(Exception): """Base error for this module.""" - pass class NewerCredentialStoreError(Error): - """The credential store is a newer version that supported.""" - pass + """The credential store is a newer version than supported.""" @util.positional(4) @@ -125,6 +137,43 @@ def get_credential_storage_custom_key( An object derived from client.Storage for getting/setting the credential. """ + multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly) + key = util.dict_to_tuple_key(key_dict) + return multistore._get_storage(key) + + +@util.positional(1) +def get_all_credential_keys(filename, warn_on_readonly=True): + """Gets all the registered credential keys in the given Multistore. + + Args: + filename: The JSON file storing a set of credentials + warn_on_readonly: if True, log a warning if the store is readonly + + Returns: + A list of the credential keys present in the file. They are returned as + dictionaries that can be passed into get_credential_storage_custom_key to + get the actual credentials. + """ + multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly) + multistore._lock() + try: + return multistore._get_all_credential_keys() + finally: + multistore._unlock() + + +@util.positional(1) +def _get_multistore(filename, warn_on_readonly=True): + """A helper method to initialize the multistore with proper locking. 
+ + Args: + filename: The JSON file storing a set of credentials + warn_on_readonly: if True, log a warning if the store is readonly + + Returns: + A multistore object + """ filename = os.path.expanduser(filename) _multistores_lock.acquire() try: @@ -132,8 +181,7 @@ def get_credential_storage_custom_key( filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly)) finally: _multistores_lock.release() - key = util.dict_to_tuple_key(key_dict) - return multistore._get_storage(key) + return multistore class _MultiStore(object): @@ -145,7 +193,7 @@ def __init__(self, filename, warn_on_readonly=True): This will create the file if necessary. """ - self._file = LockedFile(filename, 'r+b', 'rb') + self._file = LockedFile(filename, 'r+', 'r') self._thread_lock = threading.Lock() self._read_only = False self._warn_on_readonly = warn_on_readonly @@ -223,7 +271,7 @@ def _create_file_if_needed(self): simple version of "touch" to ensure the file has been created. """ if not os.path.exists(self._file.filename()): - old_umask = os.umask(0177) + old_umask = os.umask(0o177) try: open(self._file.filename(), 'a+b').close() finally: @@ -232,13 +280,23 @@ def _create_file_if_needed(self): def _lock(self): """Lock the entire multistore.""" self._thread_lock.acquire() - self._file.open_and_lock() + try: + self._file.open_and_lock() + except IOError as e: + if e.errno == errno.ENOSYS: + logger.warn('File system does not support locking the credentials ' + 'file.') + elif e.errno == errno.ENOLCK: + logger.warn('File system is out of resources for writing the ' + 'credentials file (is your disk full?).') + else: + raise if not self._file.is_locked(): self._read_only = True if self._warn_on_readonly: logger.warn('The credentials file (%s) is not writable. Opening in ' 'read-only mode. Any refreshed credentials will only be ' - 'valid for this run.' 
% self._file.filename()) + 'valid for this run.', self._file.filename()) if os.path.getsize(self._file.filename()) == 0: logger.debug('Initializing empty multistore file') # The multistore is empty so write out an empty file. @@ -267,7 +325,7 @@ def _locked_json_read(self): """ assert self._thread_lock.locked() self._file.file_handle().seek(0) - return simplejson.load(self._file.file_handle()) + return json.load(self._file.file_handle()) def _locked_json_write(self, data): """Write a JSON serializable data structure to the multistore. @@ -281,7 +339,7 @@ def _locked_json_write(self, data): if self._read_only: return self._file.file_handle().seek(0) - simplejson.dump(data, self._file.file_handle(), sort_keys=True, indent=2) + json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': ')) self._file.file_handle().truncate() def _refresh_data_cache(self): @@ -339,7 +397,7 @@ def _decode_credential_from_json(self, cred_entry): raw_key = cred_entry['key'] key = util.dict_to_tuple_key(raw_key) credential = None - credential = Credentials.new_from_json(simplejson.dumps(cred_entry['credential'])) + credential = Credentials.new_from_json(json.dumps(cred_entry['credential'])) return (key, credential) def _write(self): @@ -352,10 +410,18 @@ def _write(self): raw_data['data'] = raw_creds for (cred_key, cred) in self._data.items(): raw_key = dict(cred_key) - raw_cred = simplejson.loads(cred.to_json()) + raw_cred = json.loads(cred.to_json()) raw_creds.append({'key': raw_key, 'credential': raw_cred}) self._locked_json_write(raw_data) + def _get_all_credential_keys(self): + """Gets all the registered credential keys in the multistore. + + Returns: + A list of dictionaries corresponding to all the keys currently registered + """ + return [dict(key) for key in self._data.keys()] + def _get_credential(self, key): """Get a credential from the multistore. 
diff --git a/oauth2client/old_run.py b/oauth2client/old_run.py new file mode 100644 index 0000000..51db69b --- /dev/null +++ b/oauth2client/old_run.py @@ -0,0 +1,161 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module holds the old run() function which is deprecated, the +tools.run_flow() function should be used in its place.""" + +from __future__ import print_function + +import logging +import socket +import sys +import webbrowser + +import gflags +from six.moves import input + +from oauth2client import client +from oauth2client import util +from oauth2client.tools import ClientRedirectHandler +from oauth2client.tools import ClientRedirectServer + + +FLAGS = gflags.FLAGS + +gflags.DEFINE_boolean('auth_local_webserver', True, + ('Run a local web server to handle redirects during ' + 'OAuth authorization.')) + +gflags.DEFINE_string('auth_host_name', 'localhost', + ('Host name to use when running a local web server to ' + 'handle redirects during OAuth authorization.')) + +gflags.DEFINE_multi_int('auth_host_port', [8080, 8090], + ('Port to use when running a local web server to ' + 'handle redirects during OAuth authorization.')) + + +@util.positional(2) +def run(flow, storage, http=None): + """Core code for a command-line application. + + The ``run()`` function is called from your application and runs + through all the steps to obtain credentials. 
It takes a ``Flow`` + argument and attempts to open an authorization server page in the + user's default web browser. The server asks the user to grant your + application access to the user's data. If the user grants access, + the ``run()`` function returns new credentials. The new credentials + are also stored in the ``storage`` argument, which updates the file + associated with the ``Storage`` object. + + It presumes it is run from a command-line application and supports the + following flags: + + ``--auth_host_name`` (string, default: ``localhost``) + Host name to use when running a local web server to handle + redirects during OAuth authorization. + + ``--auth_host_port`` (integer, default: ``[8080, 8090]``) + Port to use when running a local web server to handle redirects + during OAuth authorization. Repeat this option to specify a list + of values. + + ``--[no]auth_local_webserver`` (boolean, default: ``True``) + Run a local web server to handle redirects during OAuth authorization. + + Since it uses flags make sure to initialize the ``gflags`` module before + calling ``run()``. + + Args: + flow: Flow, an OAuth 2.0 Flow to step through. + storage: Storage, a ``Storage`` to store the credential in. + http: An instance of ``httplib2.Http.request`` or something that acts + like it. + + Returns: + Credentials, the obtained credential. + """ + logging.warning('This function, oauth2client.tools.run(), and the use of ' + 'the gflags library are deprecated and will be removed in a future ' + 'version of the library.') + if FLAGS.auth_local_webserver: + success = False + port_number = 0 + for port in FLAGS.auth_host_port: + port_number = port + try: + httpd = ClientRedirectServer((FLAGS.auth_host_name, port), + ClientRedirectHandler) + except socket.error as e: + pass + else: + success = True + break + FLAGS.auth_local_webserver = success + if not success: + print('Failed to start a local webserver listening on either port 8080') + print('or port 8090. 
Please check your firewall settings and locally') + print('running programs that may be blocking or using those ports.') + print() + print('Falling back to --noauth_local_webserver and continuing with') + print('authorization.') + print() + + if FLAGS.auth_local_webserver: + oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number) + else: + oauth_callback = client.OOB_CALLBACK_URN + flow.redirect_uri = oauth_callback + authorize_url = flow.step1_get_authorize_url() + + if FLAGS.auth_local_webserver: + webbrowser.open(authorize_url, new=1, autoraise=True) + print('Your browser has been opened to visit:') + print() + print(' ' + authorize_url) + print() + print('If your browser is on a different machine then exit and re-run') + print('this application with the command-line parameter ') + print() + print(' --noauth_local_webserver') + print() + else: + print('Go to the following link in your browser:') + print() + print(' ' + authorize_url) + print() + + code = None + if FLAGS.auth_local_webserver: + httpd.handle_request() + if 'error' in httpd.query_params: + sys.exit('Authentication request was rejected.') + if 'code' in httpd.query_params: + code = httpd.query_params['code'] + else: + print('Failed to find "code" in the query parameters of the redirect.') + sys.exit('Try running with --noauth_local_webserver.') + else: + code = input('Enter verification code: ').strip() + + try: + credential = flow.step2_exchange(code, http=http) + except client.FlowExchangeError as e: + sys.exit('Authentication has failed: %s' % e) + + storage.put(credential) + credential.set_store(storage) + print('Authentication successful.') + + return credential diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py new file mode 100644 index 0000000..d1d1d89 --- /dev/null +++ b/oauth2client/service_account.py @@ -0,0 +1,139 @@ +# Copyright 2014 Google Inc. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A service account credentials class. + +This credentials class is implemented on top of rsa library. +""" + +import base64 +import json +import six +import time + +from pyasn1.codec.ber import decoder +from pyasn1_modules.rfc5208 import PrivateKeyInfo +import rsa + +from oauth2client import GOOGLE_REVOKE_URI +from oauth2client import GOOGLE_TOKEN_URI +from oauth2client import util +from oauth2client.client import AssertionCredentials + + +class _ServiceAccountCredentials(AssertionCredentials): + """Class representing a service account (signed JWT) credential.""" + + MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds + + def __init__(self, service_account_id, service_account_email, private_key_id, + private_key_pkcs8_text, scopes, user_agent=None, + token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI, + **kwargs): + + super(_ServiceAccountCredentials, self).__init__( + None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri) + + self._service_account_id = service_account_id + self._service_account_email = service_account_email + self._private_key_id = private_key_id + self._private_key = _get_private_key(private_key_pkcs8_text) + self._private_key_pkcs8_text = private_key_pkcs8_text + self._scopes = util.scopes_to_string(scopes) + self._user_agent = user_agent + self._token_uri = token_uri + self._revoke_uri = revoke_uri + self._kwargs = kwargs + + def _generate_assertion(self): + """Generate the 
assertion that will be used in the request.""" + + header = { + 'alg': 'RS256', + 'typ': 'JWT', + 'kid': self._private_key_id + } + + now = int(time.time()) + payload = { + 'aud': self._token_uri, + 'scope': self._scopes, + 'iat': now, + 'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS, + 'iss': self._service_account_email + } + payload.update(self._kwargs) + + assertion_input = (_urlsafe_b64encode(header) + b'.' + + _urlsafe_b64encode(payload)) + + # Sign the assertion. + rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256') + signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=') + + return assertion_input + b'.' + signature + + def sign_blob(self, blob): + # Ensure that it is bytes + try: + blob = blob.encode('utf-8') + except AttributeError: + pass + return (self._private_key_id, + rsa.pkcs1.sign(blob, self._private_key, 'SHA-256')) + + @property + def service_account_email(self): + return self._service_account_email + + @property + def serialization_data(self): + return { + 'type': 'service_account', + 'client_id': self._service_account_id, + 'client_email': self._service_account_email, + 'private_key_id': self._private_key_id, + 'private_key': self._private_key_pkcs8_text + } + + def create_scoped_required(self): + return not self._scopes + + def create_scoped(self, scopes): + return _ServiceAccountCredentials(self._service_account_id, + self._service_account_email, + self._private_key_id, + self._private_key_pkcs8_text, + scopes, + user_agent=self._user_agent, + token_uri=self._token_uri, + revoke_uri=self._revoke_uri, + **self._kwargs) + + +def _urlsafe_b64encode(data): + return base64.urlsafe_b64encode( + json.dumps(data, separators=(',', ':')).encode('UTF-8')).rstrip(b'=') + + +def _get_private_key(private_key_pkcs8_text): + """Get an RSA private key object from a pkcs8 representation.""" + + if not isinstance(private_key_pkcs8_text, six.binary_type): + private_key_pkcs8_text = 
private_key_pkcs8_text.encode('ascii') + der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY') + asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo()) + return rsa.PrivateKey.load_pkcs1( + asn1_private_key.getComponentByName('privateKey').asOctets(), + format='DER') diff --git a/oauth2client/tools.py b/oauth2client/tools.py index 6733230..3c72903 100644 --- a/oauth2client/tools.py +++ b/oauth2client/tools.py @@ -1,4 +1,4 @@ -# Copyright (C) 2010 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,39 +19,55 @@ the same directory. """ -__author__ = 'jcgregorio@google.com (Joe Gregorio)' -__all__ = ['run'] +from __future__ import print_function +__author__ = 'jcgregorio@google.com (Joe Gregorio)' +__all__ = ['argparser', 'run_flow', 'run', 'message_if_missing'] -import BaseHTTPServer -import gflags +import logging import socket import sys -import webbrowser -from oauth2client.client import FlowExchangeError -from oauth2client.client import OOB_CALLBACK_URN +from six.moves import BaseHTTPServer +from six.moves import urllib +from six.moves import input + +from oauth2client import client from oauth2client import util -try: - from urlparse import parse_qsl -except ImportError: - from cgi import parse_qsl +_CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0 -FLAGS = gflags.FLAGS +To make this sample run you will need to populate the client_secrets.json file +found at: -gflags.DEFINE_boolean('auth_local_webserver', True, - ('Run a local web server to handle redirects during ' - 'OAuth authorization.')) + %s -gflags.DEFINE_string('auth_host_name', 'localhost', - ('Host name to use when running a local web server to ' - 'handle redirects during OAuth authorization.')) +with information from the APIs Console . 
-gflags.DEFINE_multi_int('auth_host_port', [8080, 8090], - ('Port to use when running a local web server to ' - 'handle redirects during OAuth authorization.')) +""" + +def _CreateArgumentParser(): + try: + import argparse + except ImportError: + return None + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument('--auth_host_name', default='localhost', + help='Hostname when running a local web server.') + parser.add_argument('--noauth_local_webserver', action='store_true', + default=False, help='Do not run a local web server.') + parser.add_argument('--auth_host_port', default=[8080, 8090], type=int, + nargs='*', help='Port web server should listen on.') + parser.add_argument('--logging_level', default='ERROR', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='Set the logging level of detail.') + return parser + +# argparser is an ArgumentParser that contains command-line options expected +# by tools.run(). Pass it in as part of the 'parents' argument to your own +# ArgumentParser. +argparser = _CreateArgumentParser() class ClientRedirectServer(BaseHTTPServer.HTTPServer): @@ -70,147 +86,172 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): into the servers query_params and then stops serving. """ - def do_GET(s): + def do_GET(self): """Handle a GET request. Parses the query parameters and prints a message if the flow has completed. Note that we can't detect if an error occurred. """ - s.send_response(200) - s.send_header("Content-type", "text/html") - s.end_headers() - query = s.path.split('?', 1)[-1] - query = dict(parse_qsl(query)) - s.server.query_params = query - s.wfile.write("Authentication Status") - s.wfile.write("

The authentication flow has completed.

") - s.wfile.write("") + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + query = self.path.split('?', 1)[-1] + query = dict(urllib.parse.parse_qsl(query)) + self.server.query_params = query + self.wfile.write(b"Authentication Status") + self.wfile.write(b"

The authentication flow has completed.

") + self.wfile.write(b"") def log_message(self, format, *args): """Do not log messages to stdout while running as command line program.""" - pass -@util.positional(2) -def run(flow, storage, http=None, short_url=False): +@util.positional(3) +def run_flow(flow, storage, flags, http=None): """Core code for a command-line application. - The run() function is called from your application and runs through all the - steps to obtain credentials. It takes a Flow argument and attempts to open an - authorization server page in the user's default web browser. The server asks - the user to grant your application access to the user's data. If the user - grants access, the run() function returns new credentials. The new credentials - are also stored in the Storage argument, which updates the file associated - with the Storage object. + The ``run()`` function is called from your application and runs + through all the steps to obtain credentials. It takes a ``Flow`` + argument and attempts to open an authorization server page in the + user's default web browser. The server asks the user to grant your + application access to the user's data. If the user grants access, + the ``run()`` function returns new credentials. The new credentials + are also stored in the ``storage`` argument, which updates the file + associated with the ``Storage`` object. It presumes it is run from a command-line application and supports the following flags: - --auth_host_name: Host name to use when running a local web server - to handle redirects during OAuth authorization. - (default: 'localhost') + ``--auth_host_name`` (string, default: ``localhost``) + Host name to use when running a local web server to handle + redirects during OAuth authorization. + + ``--auth_host_port`` (integer, default: ``[8080, 8090]``) + Port to use when running a local web server to handle redirects + during OAuth authorization. Repeat this option to specify a list + of values. 
+ + ``--[no]auth_local_webserver`` (boolean, default: ``True``) + Run a local web server to handle redirects during OAuth authorization. - --auth_host_port: Port to use when running a local web server to handle - redirects during OAuth authorization.; - repeat this option to specify a list of values - (default: '[8080, 8090]') - (an integer) - --[no]auth_local_webserver: Run a local web server to handle redirects - during OAuth authorization. - (default: 'true') - Since it uses flags make sure to initialize the gflags module before calling - run(). + + The tools module defines an ``ArgumentParser`` the already contains the flag + definitions that ``run()`` requires. You can pass that ``ArgumentParser`` to your + ``ArgumentParser`` constructor:: + + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + parents=[tools.argparser]) + flags = parser.parse_args(argv) Args: flow: Flow, an OAuth 2.0 Flow to step through. - storage: Storage, a Storage to store the credential in. - http: An instance of httplib2.Http.request - or something that acts like it. + storage: Storage, a ``Storage`` to store the credential in. + flags: ``argparse.Namespace``, The command-line flags. This is the + object returned from calling ``parse_args()`` on + ``argparse.ArgumentParser`` as described above. + http: An instance of ``httplib2.Http.request`` or something that + acts like it. Returns: Credentials, the obtained credential. 
""" - if FLAGS.auth_local_webserver: + logging.getLogger().setLevel(getattr(logging, flags.logging_level)) + if not flags.noauth_local_webserver: success = False port_number = 0 - for port in FLAGS.auth_host_port: + for port in flags.auth_host_port: port_number = port try: - httpd = ClientRedirectServer((FLAGS.auth_host_name, port), + httpd = ClientRedirectServer((flags.auth_host_name, port), ClientRedirectHandler) - except socket.error, e: + except socket.error: pass else: success = True break - FLAGS.auth_local_webserver = success + flags.noauth_local_webserver = not success if not success: - print 'Failed to start a local webserver listening on either port 8080' - print 'or port 9090. Please check your firewall settings and locally' - print 'running programs that may be blocking or using those ports.' - print - print 'Falling back to --noauth_local_webserver and continuing with', - print 'authorization.' - print - - if FLAGS.auth_local_webserver: - oauth_callback = 'http://%s:%s/' % (FLAGS.auth_host_name, port_number) + print('Failed to start a local webserver listening on either port 8080') + print('or port 9090. 
Please check your firewall settings and locally') + print('running programs that may be blocking or using those ports.') + print() + print('Falling back to --noauth_local_webserver and continuing with') + print('authorization.') + print() + + if not flags.noauth_local_webserver: + oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number) else: - oauth_callback = OOB_CALLBACK_URN + oauth_callback = client.OOB_CALLBACK_URN flow.redirect_uri = oauth_callback authorize_url = flow.step1_get_authorize_url() - - if short_url: + + if flags.short_url: try: - from apiclient.discovery import build + from googleapiclient.discovery import build service = build('urlshortener', 'v1', http=http) url_result = service.url().insert(body={'longUrl': authorize_url}, key=u'AIzaSyBlmgbii8QfJSYmC9VTMOfqrAt5Vj5wtzE').execute() authorize_url = url_result['id'] except: pass - if FLAGS.auth_local_webserver: + + if not flags.noauth_local_webserver: + import webbrowser webbrowser.open(authorize_url, new=1, autoraise=True) - print 'Your browser has been opened to visit:' - print - print ' ' + authorize_url - print - print 'If your browser is on a different machine then exit and re-run this' - print 'after creating a file called nobrowser.txt in the same path as GYB' - print -# print 'application with the command-line parameter ' -# print -# print ' --noauth_local_webserver' -# print + print('Your browser has been opened to visit:') + print() + print(' ' + authorize_url) + print() + print('If your browser is on a different machine then exit and re-run this') + print('after creating a file called nobrowser.txt in the same path as GAM.') + print() else: - print 'Go to the following link in your browser:' - print - print ' ' + authorize_url - print + print('Go to the following link in your browser:') + print() + print(' ' + authorize_url) + print() code = None - if FLAGS.auth_local_webserver: + if not flags.noauth_local_webserver: httpd.handle_request() if 'error' in httpd.query_params: 
sys.exit('Authentication request was rejected.') if 'code' in httpd.query_params: code = httpd.query_params['code'] else: - print 'Failed to find "code" in the query parameters of the redirect.' + print('Failed to find "code" in the query parameters of the redirect.') sys.exit('Try running with --noauth_local_webserver.') else: - code = raw_input('Enter verification code: ').strip() + code = input('Enter verification code: ').strip() try: credential = flow.step2_exchange(code, http=http) - except FlowExchangeError, e: + except client.FlowExchangeError as e: sys.exit('Authentication has failed: %s' % e) storage.put(credential) credential.set_store(storage) - print 'Authentication successful.' + print('Authentication successful.') return credential + + +def message_if_missing(filename): + """Helpful message to display if the CLIENT_SECRETS file is missing.""" + + return _CLIENT_SECRETS_MESSAGE % filename + +try: + from oauth2client.old_run import run + from oauth2client.old_run import FLAGS +except ImportError: + def run(*args, **kwargs): + raise NotImplementedError( + 'The gflags library must be installed to use tools.run(). ' + 'Please install gflags or preferrably switch to using ' + 'tools.run_flow().') diff --git a/oauth2client/util.py b/oauth2client/util.py index ee6a100..a706f02 100644 --- a/oauth2client/util.py +++ b/oauth2client/util.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright 2010 Google Inc. +# Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,86 +17,93 @@ """Common utility library.""" -__author__ = ['rafek@google.com (Rafe Kaplan)', - 'guido@google.com (Guido van Rossum)', +__author__ = [ + 'rafek@google.com (Rafe Kaplan)', + 'guido@google.com (Guido van Rossum)', ] + __all__ = [ - 'positional', + 'positional', + 'POSITIONAL_WARNING', + 'POSITIONAL_EXCEPTION', + 'POSITIONAL_IGNORE', ] -import gflags +import functools import inspect import logging +import sys import types -import urllib -import urlparse -try: - from urlparse import parse_qsl -except ImportError: - from cgi import parse_qsl +import six +from six.moves import urllib -logger = logging.getLogger(__name__) -FLAGS = gflags.FLAGS +logger = logging.getLogger(__name__) -gflags.DEFINE_enum('positional_parameters_enforcement', 'WARNING', - ['EXCEPTION', 'WARNING', 'IGNORE'], - 'The action when an oauth2client.util.positional declaration is violated.') +POSITIONAL_WARNING = 'WARNING' +POSITIONAL_EXCEPTION = 'EXCEPTION' +POSITIONAL_IGNORE = 'IGNORE' +POSITIONAL_SET = frozenset([POSITIONAL_WARNING, POSITIONAL_EXCEPTION, + POSITIONAL_IGNORE]) +positional_parameters_enforcement = POSITIONAL_WARNING def positional(max_positional_args): """A decorator to declare that only the first N arguments my be positional. - This decorator makes it easy to support Python 3 style key-word only - parameters. For example, in Python 3 it is possible to write: + This decorator makes it easy to support Python 3 style keyword-only + parameters. For example, in Python 3 it is possible to write:: def fn(pos1, *, kwonly1=None, kwonly1=None): ... - All named parameters after * must be a keyword: + All named parameters after ``*`` must be a keyword:: fn(10, 'kw1', 'kw2') # Raises exception. fn(10, kwonly1='kw1') # Ok. - Example: - To define a function like above, do: + Example + ^^^^^^^ - @positional(1) - def fn(pos1, kwonly1=None, kwonly2=None): - ... 
+ To define a function like above, do:: - If no default value is provided to a keyword argument, it becomes a required - keyword argument: + @positional(1) + def fn(pos1, kwonly1=None, kwonly2=None): + ... - @positional(0) - def fn(required_kw): - ... + If no default value is provided to a keyword argument, it becomes a required + keyword argument:: - This must be called with the keyword parameter: + @positional(0) + def fn(required_kw): + ... - fn() # Raises exception. - fn(10) # Raises exception. - fn(required_kw=10) # Ok. + This must be called with the keyword parameter:: - When defining instance or class methods always remember to account for - 'self' and 'cls': + fn() # Raises exception. + fn(10) # Raises exception. + fn(required_kw=10) # Ok. - class MyClass(object): + When defining instance or class methods always remember to account for + ``self`` and ``cls``:: - @positional(2) - def my_method(self, pos1, kwonly1=None): - ... + class MyClass(object): - @classmethod - @positional(2) - def my_method(cls, pos1, kwonly1=None): - ... + @positional(2) + def my_method(self, pos1, kwonly1=None): + ... - The positional decorator behavior is controlled by the - --positional_parameters_enforcement flag. The flag may be set to 'EXCEPTION', - 'WARNING' or 'IGNORE' to raise an exception, log a warning, or do nothing, - respectively, if a declaration is violated. + @classmethod + @positional(2) + def my_method(cls, pos1, kwonly1=None): + ... + + The positional decorator behavior is controlled by + ``util.positional_parameters_enforcement``, which may be set to + ``POSITIONAL_EXCEPTION``, ``POSITIONAL_WARNING`` or + ``POSITIONAL_IGNORE`` to raise an exception, log a warning, or do + nothing, respectively, if a declaration is violated. Args: max_positional_arguments: Maximum number of positional arguments. All @@ -107,11 +114,13 @@ def my_method(cls, pos1, kwonly1=None): being used as positional parameters. 
Raises: - TypeError if a key-word only argument is provided as a positional parameter, - but only if the --positional_parameters_enforcement flag is set to - 'EXCEPTION'. + TypeError if a key-word only argument is provided as a positional + parameter, but only if util.positional_parameters_enforcement is set to + POSITIONAL_EXCEPTION. + """ def positional_decorator(wrapped): + @functools.wraps(wrapped) def positional_wrapper(*args, **kwargs): if len(args) > max_positional_args: plural_s = '' @@ -119,16 +128,16 @@ def positional_wrapper(*args, **kwargs): plural_s = 's' message = '%s() takes at most %d positional argument%s (%d given)' % ( wrapped.__name__, max_positional_args, plural_s, len(args)) - if FLAGS.positional_parameters_enforcement == 'EXCEPTION': + if positional_parameters_enforcement == POSITIONAL_EXCEPTION: raise TypeError(message) - elif FLAGS.positional_parameters_enforcement == 'WARNING': + elif positional_parameters_enforcement == POSITIONAL_WARNING: logger.warning(message) else: # IGNORE pass return wrapped(*args, **kwargs) return positional_wrapper - if isinstance(max_positional_args, (int, long)): + if isinstance(max_positional_args, six.integer_types): return positional_decorator else: args, _, _, defaults = inspect.getargspec(max_positional_args) @@ -148,7 +157,7 @@ def scopes_to_string(scopes): Returns: The scopes formatted as a single string. 
""" - if isinstance(scopes, types.StringTypes): + if isinstance(scopes, six.string_types): return scopes else: return ' '.join(scopes) @@ -185,8 +194,8 @@ def _add_query_parameter(url, name, value): if value is None: return url else: - parsed = list(urlparse.urlparse(url)) - q = dict(parse_qsl(parsed[4])) + parsed = list(urllib.parse.urlparse(url)) + q = dict(urllib.parse.parse_qsl(parsed[4])) q[name] = value - parsed[4] = urllib.urlencode(q) - return urlparse.urlunparse(parsed) + parsed[4] = urllib.parse.urlencode(q) + return urllib.parse.urlunparse(parsed) diff --git a/oauth2client/xsrfutil.py b/oauth2client/xsrfutil.py index 7e1fe5c..5739dcf 100644 --- a/oauth2client/xsrfutil.py +++ b/oauth2client/xsrfutil.py @@ -1,6 +1,5 @@ -#!/usr/bin/python2.5 # -# Copyright 2010 the Melange authors. +# Copyright 2014 the Melange authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,25 +16,36 @@ """Helper methods for creating & verifying XSRF tokens.""" __authors__ = [ - '"Doug Coker" ', - '"Joe Gregorio" ', + '"Doug Coker" ', + '"Joe Gregorio" ', ] import base64 import hmac -import os # for urandom import time +import six from oauth2client import util # Delimiter character -DELIMITER = ':' +DELIMITER = b':' + # 1 hour in seconds DEFAULT_TIMEOUT_SECS = 1*60*60 + +def _force_bytes(s): + if isinstance(s, bytes): + return s + s = str(s) + if isinstance(s, six.text_type): + return s.encode('utf-8') + return s + + @util.positional(2) def generate_token(key, user_id, action_id="", when=None): """Generates a URL-safe token for the given user, action, time tuple. @@ -51,18 +61,16 @@ def generate_token(key, user_id, action_id="", when=None): Returns: A string XSRF protection token. 
""" - when = when or int(time.time()) - digester = hmac.new(key) - digester.update(str(user_id)) + when = _force_bytes(when or int(time.time())) + digester = hmac.new(_force_bytes(key)) + digester.update(_force_bytes(user_id)) digester.update(DELIMITER) - digester.update(action_id) + digester.update(_force_bytes(action_id)) digester.update(DELIMITER) - digester.update(str(when)) + digester.update(when) digest = digester.digest() - token = base64.urlsafe_b64encode('%s%s%d' % (digest, - DELIMITER, - when)) + token = base64.urlsafe_b64encode(digest + DELIMITER + when) return token @@ -87,8 +95,8 @@ def validate_token(key, token, user_id, action_id="", current_time=None): if not token: return False try: - decoded = base64.urlsafe_b64decode(str(token)) - token_time = long(decoded.split(DELIMITER)[-1]) + decoded = base64.urlsafe_b64decode(token) + token_time = int(decoded.split(DELIMITER)[-1]) except (TypeError, ValueError): return False if current_time is None: @@ -105,9 +113,6 @@ def validate_token(key, token, user_id, action_id="", current_time=None): # Perform constant time comparison to avoid timing attacks different = 0 - for x, y in zip(token, expected_token): - different |= ord(x) ^ ord(y) - if different: - return False - - return True + for x, y in zip(bytearray(token), bytearray(expected_token)): + different |= x ^ y + return not different diff --git a/uritemplate/__init__.py b/uritemplate/__init__.py index 5d0ebce..712405d 100644 --- a/uritemplate/__init__.py +++ b/uritemplate/__init__.py @@ -1,147 +1,265 @@ -# Early, and incomplete implementation of -04. -# +#!/usr/bin/env python + +""" +URI Template (RFC6570) Processor +""" + +__copyright__ = """\ +Copyright 2011-2013 Joe Gregorio + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + import re -import urllib +try: + from urllib.parse import quote +except ImportError: + from urllib import quote + + + +__version__ = "0.6" RESERVED = ":/?#[]@!$&'()*+,;=" -OPERATOR = "+./;?|!@" -EXPLODE = "*+" +OPERATOR = "+#./;?&|!@" MODIFIER = ":^" -TEMPLATE = re.compile(r"{(?P[\+\./;\?|!@])?(?P[^}]+)}", re.UNICODE) -VAR = re.compile(r"^(?P[^=\+\*:\^]+)((?P[\+\*])|(?P[:\^]-?[0-9]+))?(=(?P.*))?$", re.UNICODE) +TEMPLATE = re.compile("{([^\}]+)}") -def _tostring(varname, value, explode, operator, safe=""): - if type(value) == type([]): - if explode == "+": - return ",".join([varname + "." + urllib.quote(x, safe) for x in value]) - else: - return ",".join([urllib.quote(x, safe) for x in value]) - if type(value) == type({}): - keys = value.keys() - keys.sort() - if explode == "+": - return ",".join([varname + "." + urllib.quote(key, safe) + "," + urllib.quote(value[key], safe) for key in keys]) - else: - return ",".join([urllib.quote(key, safe) + "," + urllib.quote(value[key], safe) for key in keys]) - else: - return urllib.quote(value, safe) - - -def _tostring_path(varname, value, explode, operator, safe=""): - joiner = operator - if type(value) == type([]): - if explode == "+": - return joiner.join([varname + "." + urllib.quote(x, safe) for x in value]) - elif explode == "*": - return joiner.join([urllib.quote(x, safe) for x in value]) - else: - return ",".join([urllib.quote(x, safe) for x in value]) - elif type(value) == type({}): - keys = value.keys() - keys.sort() - if explode == "+": - return joiner.join([varname + "." 
+ urllib.quote(key, safe) + joiner + urllib.quote(value[key], safe) for key in keys]) - elif explode == "*": - return joiner.join([urllib.quote(key, safe) + joiner + urllib.quote(value[key], safe) for key in keys]) - else: - return ",".join([urllib.quote(key, safe) + "," + urllib.quote(value[key], safe) for key in keys]) - else: - if value: - return urllib.quote(value, safe) + +def variables(template): + '''Returns the set of keywords in a uri template''' + vars = set() + for varlist in TEMPLATE.findall(template): + if varlist[0] in OPERATOR: + varlist = varlist[1:] + varspecs = varlist.split(',') + for var in varspecs: + # handle prefix values + var = var.split(':')[0] + # handle composite values + if var.endswith('*'): + var = var[:-1] + vars.add(var) + return vars + + +def _quote(value, safe, prefix=None): + if prefix is not None: + return quote(str(value)[:prefix], safe) + return quote(str(value), safe) + + +def _tostring(varname, value, explode, prefix, operator, safe=""): + if isinstance(value, list): + return ",".join([_quote(x, safe) for x in value]) + if isinstance(value, dict): + keys = sorted(value.keys()) + if explode: + return ",".join([_quote(key, safe) + "=" + \ + _quote(value[key], safe) for key in keys]) + else: + return ",".join([_quote(key, safe) + "," + \ + _quote(value[key], safe) for key in keys]) + elif value is None: + return else: - return "" - -def _tostring_query(varname, value, explode, operator, safe=""): - joiner = operator - varprefix = "" - if operator == "?": - joiner = "&" - varprefix = varname + "=" - if type(value) == type([]): - if 0 == len(value): - return "" - if explode == "+": - return joiner.join([varname + "=" + urllib.quote(x, safe) for x in value]) - elif explode == "*": - return joiner.join([urllib.quote(x, safe) for x in value]) + return _quote(value, safe, prefix) + + +def _tostring_path(varname, value, explode, prefix, operator, safe=""): + joiner = operator + if isinstance(value, list): + if explode: + out = 
[_quote(x, safe) for x in value if value is not None] + else: + joiner = "," + out = [_quote(x, safe) for x in value if value is not None] + if out: + return joiner.join(out) + else: + return + elif isinstance(value, dict): + keys = sorted(value.keys()) + if explode: + out = [_quote(key, safe) + "=" + \ + _quote(value[key], safe) for key in keys \ + if value[key] is not None] + else: + joiner = "," + out = [_quote(key, safe) + "," + \ + _quote(value[key], safe) \ + for key in keys if value[key] is not None] + if out: + return joiner.join(out) + else: + return + elif value is None: + return else: - return varprefix + ",".join([urllib.quote(x, safe) for x in value]) - elif type(value) == type({}): - if 0 == len(value): - return "" - keys = value.keys() - keys.sort() - if explode == "+": - return joiner.join([varname + "." + urllib.quote(key, safe) + "=" + urllib.quote(value[key], safe) for key in keys]) - elif explode == "*": - return joiner.join([urllib.quote(key, safe) + "=" + urllib.quote(value[key], safe) for key in keys]) + return _quote(value, safe, prefix) + + +def _tostring_semi(varname, value, explode, prefix, operator, safe=""): + joiner = operator + if operator == "?": + joiner = "&" + if isinstance(value, list): + if explode: + out = [varname + "=" + _quote(x, safe) \ + for x in value if x is not None] + if out: + return joiner.join(out) + else: + return + else: + return varname + "=" + ",".join([_quote(x, safe) \ + for x in value]) + elif isinstance(value, dict): + keys = sorted(value.keys()) + if explode: + return joiner.join([_quote(key, safe) + "=" + \ + _quote(value[key], safe) \ + for key in keys if key is not None]) + else: + return varname + "=" + ",".join([_quote(key, safe) + "," + \ + _quote(value[key], safe) for key in keys \ + if key is not None]) else: - return varprefix + ",".join([urllib.quote(key, safe) + "," + urllib.quote(value[key], safe) for key in keys]) - else: - if value: - return varname + "=" + urllib.quote(value, safe) + if value 
is None: + return + elif value: + return (varname + "=" + _quote(value, safe, prefix)) + else: + return varname + + +def _tostring_query(varname, value, explode, prefix, operator, safe=""): + joiner = operator + if operator in ["?", "&"]: + joiner = "&" + if isinstance(value, list): + if 0 == len(value): + return None + if explode: + return joiner.join([varname + "=" + _quote(x, safe) \ + for x in value]) + else: + return (varname + "=" + ",".join([_quote(x, safe) \ + for x in value])) + elif isinstance(value, dict): + if 0 == len(value): + return None + keys = sorted(value.keys()) + if explode: + return joiner.join([_quote(key, safe) + "=" + \ + _quote(value[key], safe) \ + for key in keys]) + else: + return varname + "=" + \ + ",".join([_quote(key, safe) + "," + \ + _quote(value[key], safe) for key in keys]) else: - return varname + if value is None: + return + elif value: + return (varname + "=" + _quote(value, safe, prefix)) + else: + return (varname + "=") + TOSTRING = { "" : _tostring, "+": _tostring, - ";": _tostring_query, + "#": _tostring, + ";": _tostring_semi, "?": _tostring_query, + "&": _tostring_query, "/": _tostring_path, ".": _tostring_path, } -def expand(template, vars): - def _sub(match): - groupdict = match.groupdict() - operator = groupdict.get('operator') - if operator is None: - operator = '' - varlist = groupdict.get('varlist') - - safe = "@" - if operator == '+': - safe = RESERVED - varspecs = varlist.split(",") - varnames = [] - defaults = {} - for varspec in varspecs: - m = VAR.search(varspec) - groupdict = m.groupdict() - varname = groupdict.get('varname') - explode = groupdict.get('explode') - partial = groupdict.get('partial') - default = groupdict.get('default') - if default: - defaults[varname] = default - varnames.append((varname, explode, partial)) - - retval = [] - joiner = operator - prefix = operator - if operator == "+": - prefix = "" - joiner = "," - if operator == "?": - joiner = "&" - if operator == "": - joiner = "," - for 
varname, explode, partial in varnames: - if varname in vars: - value = vars[varname] - #if not value and (type(value) == type({}) or type(value) == type([])) and varname in defaults: - if not value and value != "" and varname in defaults: - value = defaults[varname] - elif varname in defaults: - value = defaults[varname] - else: - continue - retval.append(TOSTRING[operator](varname, value, explode, operator, safe=safe)) - if "".join(retval): - return prefix + joiner.join(retval) - else: - return "" +def expand(template, variables): + """ + Expand template as a URI Template using variables. + """ + def _sub(match): + expression = match.group(1) + operator = "" + if expression[0] in OPERATOR: + operator = expression[0] + varlist = expression[1:] + else: + varlist = expression + + safe = "" + if operator in ["+", "#"]: + safe = RESERVED + varspecs = varlist.split(",") + varnames = [] + defaults = {} + for varspec in varspecs: + default = None + explode = False + prefix = None + if "=" in varspec: + varname, default = tuple(varspec.split("=", 1)) + else: + varname = varspec + if varname[-1] == "*": + explode = True + varname = varname[:-1] + elif ":" in varname: + try: + prefix = int(varname[varname.index(":")+1:]) + except ValueError: + raise ValueError("non-integer prefix '{0}'".format( + varname[varname.index(":")+1:])) + varname = varname[:varname.index(":")] + if default: + defaults[varname] = default + varnames.append((varname, explode, prefix)) + + retval = [] + joiner = operator + start = operator + if operator == "+": + start = "" + joiner = "," + if operator == "#": + joiner = "," + if operator == "?": + joiner = "&" + if operator == "&": + start = "&" + if operator == "": + joiner = "," + for varname, explode, prefix in varnames: + if varname in variables: + value = variables[varname] + if not value and value != "" and varname in defaults: + value = defaults[varname] + elif varname in defaults: + value = defaults[varname] + else: + continue + expanded = 
TOSTRING[operator]( + varname, value, explode, prefix, operator, safe=safe) + if expanded is not None: + retval.append(expanded) + if len(retval) > 0: + return start + joiner.join(retval) + else: + return "" - return TEMPLATE.sub(_sub, template) + return TEMPLATE.sub(_sub, template)