#!/usr/bin/python -u

from httplib2 import Http
from oauth import oauth
from urlparse import urlparse
from cgi import parse_qsl
from json import loads, dumps
from urllib import quote, unquote
from optparse import OptionParser
import os

import readline
import shlex # no unicode support?
import mimetypes
import cmd
import time
import logging
import sys

def setup_logging():
    """Create and return the 'ubuntuone.restfilesclient' logger.

    Records go to stderr formatted as "<asctime>: <message>", and are not
    propagated to the root logger.
    """
    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(logging.Formatter(fmt="%(asctime)s: %(message)s"))

    log = logging.getLogger('ubuntuone.restfilesclient')
    log.propagate = False
    log.addHandler(handler)
    return log

# Module-wide logger shared by the HTTP client and the shell; the __main__
# block raises it to DEBUG when --debug is given.
logger = setup_logging()

class ServerError(Exception):
    """The server answered a request with an error status."""

class NotFoundError(ServerError):
    """The requested resource does not exist (HTTP 404)."""

class InternalError(ServerError):
    """The server failed with an HTTP 5xx status."""

class UserError(Exception):
    """A command failed in a way that should be reported to the user."""

class OAuthHttpClient(object):
    """Very simple HTTP client that OAuth-signs (HMAC-SHA1) every request.

    Call set_consumer() and set_token() before issuing requests.
    """
    def __init__(self):
        self.signature_method = oauth.OAuthSignatureMethod_HMAC_SHA1()
        self.consumer = None
        self.token = None
        self.client = Http()

    def set_consumer(self, consumer_key, consumer_secret):
        """Set the OAuth consumer credentials used to sign requests."""
        self.consumer = oauth.OAuthConsumer(consumer_key, consumer_secret)

    def set_token(self, token, token_secret):
        """Set the OAuth access token used to sign requests."""
        self.token = oauth.OAuthToken(token, token_secret)

    def _get_oauth_request_header(self, url, method):
        """Return the OAuth header dict signing *method* on *url*.

        Query-string parameters of *url* are included in the signature.
        """
        query = urlparse(url).query

        oauth_request = oauth.OAuthRequest.from_consumer_and_token(
            http_url=url,
            http_method=method,
            oauth_consumer=self.consumer,
            token=self.token,
            parameters=dict(parse_qsl(query))
        )
        oauth_request.sign_request(self.signature_method,
                                   self.consumer, self.token)
        return oauth_request.to_header()

    def request(self, url, method="GET", body=None, headers=None):
        """Send a signed request and return ``(resp, content)``.

        *content* is decoded with json.loads when the response is JSON.

        Raises NotFoundError on HTTP 404 and InternalError on any 5xx status.
        The *headers* dict passed by the caller is not modified.
        """
        # Copy: never mutate a shared default or the caller's dict.
        request_headers = dict(headers or {})
        request_headers.update(self._get_oauth_request_header(url, method))
        resp, content = self.client.request(url, method,
                                            headers=request_headers,
                                            body=body)

        oops_id = resp.get('x-oops-id', None)
        if oops_id:
            logger.error("== Server Error ==")
            logger.error("Method: %s\nURL: %s\n%s", method, url, oops_id)
            logger.error("==================")

        status = resp['status']
        if status == '404':
            raise NotFoundError(url)

        if status.startswith('5'):
            raise InternalError(url)

        # The header may be missing (e.g. empty responses) or carry
        # parameters such as "; charset=utf-8".
        content_type = resp.get('content-type', '')
        if content_type.split(';')[0].strip() == 'application/json':
            content = loads(content)

        return resp, content

def split_args(f):
    """Decorate a cmd.Cmd ``do_*`` method to receive its arguments as a list.

    The raw argument line is split with shlex (shell-like quoting) before
    *f* is called. The method's return value is passed through, so a
    command can still return True to stop ``cmdloop()``.
    """
    from functools import wraps

    @wraps(f)
    def wrapped(self, args):
        return f(self, shlex.split(args))
    return wrapped

class Cache(object):
    """In-memory key/value cache whose entries expire after a timeout."""
    DEFAULT_TIMEOUT = 60

    def __init__(self):
        # key -> {'v': value, 'i': absolute expiry time from time.time()}
        self.cache = {}

    def set(self, key, value, timeout=DEFAULT_TIMEOUT):
        """Store *value* under *key*, valid for *timeout* seconds."""
        self.cache[key] = {'v': value,
                           'i': time.time() + timeout}

    def get(self, key, default=None):
        """Return the live value for *key*, or *default* if absent/expired.

        Expired entries are removed when looked up.
        """
        node = self.cache.get(key)
        if node is None:
            return default
        if time.time() < node['i']:
            return node['v']
        del self.cache[key]
        return default

    def drop(self, key):
        """Remove *key* if present."""
        self.cache.pop(key, None)

    def clear(self):
        """Remove every entry."""
        self.cache.clear()

class Application(cmd.Cmd):
    """Interactive shell for the Ubuntu One file storage REST API.

    Each command is a ``do_*`` method; cmd.Cmd dispatches input lines to
    them and uses their docstrings for ``help``.
    """
    BASE_URL = 'https://one.ubuntu.com/api/file_storage/v1'
    BASE_CONTENT_URL = 'https://files.one.ubuntu.com'

    def __init__(self):
        cmd.Cmd.__init__(self)
        self.client = OAuthHttpClient()
        # Remote working directory as reported by the server ('' is root).
        self.cwd = ''
        # Set by quit/EOF to end the interactive loop in main().
        self.finished = False
        # Global cache: node metadata keyed by normalized (quoted) remote
        # path, plus the account info under 'account-info'.
        self.cache = Cache()
        self.verbose = True

    def format_data(self, value):
        """Return *value* (a byte count) as a user-friendly string."""
        units = ('b', 'KiB', 'MiB', 'GiB', 'TiB')
        for unit in units:
            # The largest unit absorbs anything bigger.
            if value < 1024.0 or unit == units[-1]:
                return "%(value).1f %(unit)s" % {'value': value,
                                                'unit': unit}
            value /= 1024.0

    def do_login(self, oauth_info):
        """Log in: login consumer_key:consumer_secret:token:token_secret"""
        parts = oauth_info.strip().split(':')
        if len(parts) != 4:
            raise UserError("Usage: login "
                            "consumer_key:consumer_secret:token:token_secret")
        consumer, consumer_secret, token, token_secret = parts

        self.client.set_consumer(consumer, consumer_secret)
        self.client.set_token(token, token_secret)

        resp, content = self.client.request(Application.BASE_URL)
        self.cache.set('account-info', content)

        if self.verbose:
            print("User id: %d, name: %s" % (content['user_id'],
                                             content['visible_name']))

            print("Usage: %s/%s" % (self.format_data(content['used_bytes']),
                                    self.format_data(content['max_bytes'])))

    def main(self, options):
        """Log in if requested, then run a command file or the shell."""
        # Apply verbosity first so --quiet also silences the login summary.
        self.verbose = options.verbose

        if options.oauth:
            self.onecmd('login %s' % (options.oauth,))

        if self.verbose:
            print("Welcome to Ubuntu One!")
        self.prompt = '> '

        if options.file:
            self._run_script(options.file)
        else:
            self._run_interactive()

    def _run_script(self, filename):
        """Execute each line of *filename* as a command, until quit."""
        with open(filename) as fh:
            commands = fh.readlines()

        for command in commands:
            command = command.strip()
            # Never write OAuth secrets to the log.
            if command.startswith('login'):
                logger.debug("> login *hidden*")
            else:
                logger.debug("> %s", command)
            if command.startswith('#'):
                continue
            self.onecmd(command)
            if self.finished:
                break

    def _run_interactive(self):
        """Run the command loop until quit/EOF, surviving command errors."""
        while not self.finished:
            try:
                self.cmdloop()
            except KeyboardInterrupt:
                # warning, not info: the logger shows only warnings and up
                # unless --debug is given, and this hint is for the user.
                logger.warning("Interrupt, use Ctrl+D or quit to exit")
            except Exception as e:
                # Keep the shell alive: report and re-enter the loop.
                logger.warning("Error: %r", e)

    def safe_quote(self, value):
        """Percent-encode *value* for a URL path, keeping '/' and '~'.

        Text is UTF-8 encoded first; quote() cannot handle non-ASCII
        unicode on Python 2.
        """
        if not isinstance(value, bytes):
            value = value.encode('utf-8')
        return quote(value, safe='/~').replace(' ', '%20')

    def normalize_path(self, requested):
        """Resolve *requested* against the cwd; return a quoted absolute path.

        Supports absolute paths and '..' / '.' segments; '..' never climbs
        above the root. The root itself is returned as ''.
        """
        if requested.startswith('/'):
            tokens = []
        else:
            tokens = [token for token in self.cwd.split('/') if token]

        for token in requested.split('/'):
            if token in ('', '.'):
                # Skip empty segments ('a//b', trailing '/') and '.'.
                continue
            if token == '..':
                if tokens:
                    tokens.pop()
            else:
                tokens.append(token)

        if not tokens:
            return ''
        return self.safe_quote('/' + '/'.join(tokens))

    def _invalidate(self, path):
        """Drop cached metadata for *path* and its parent's listing."""
        self.cache.drop(path)
        self.cache.drop(path.rsplit('/', 1)[0])

    def do_quit(self, values):
        """Exit application"""
        self.finished = True
        # A true return value makes cmd.Cmd.cmdloop() stop.
        return True

    @split_args
    def do_cd(self, values):
        """Change remote directory"""
        if not values:
            raise UserError("Usage: cd <remote directory>")
        value = values[-1]

        if value == '/':
            path = ''
        else:
            path = self.normalize_path(value)

        resp, content = self.client.request(Application.BASE_URL
                + path + '?include_children=true')

        self.cache.set(path, content)
        for child in content.get('children', ()):
            self.cache.set(self.normalize_path(child['resource_path']), child)

        self.cwd = content.get('resource_path', '')
        self.prompt = self.cwd + u'> '

    @split_args
    def do_lcd(self, values):
        """Change local directory"""
        if not values:
            raise UserError("Usage: lcd <local directory>")
        try:
            os.chdir(values[-1])
        except OSError as e:
            raise UserError(e)

    @split_args
    def do_delete(self, values):
        """Delete files"""
        if not values:
            raise UserError("Usage: delete <remote file> [...]")
        for value in values:
            self.delete_file(value)

    def do_ls(self, values):
        """List directory contents"""
        if self.cwd == '':
            rows = self._list_root()
        else:
            rows = self._list_directory(self.normalize_path(self.cwd))

        # Width measured after the '/' suffix is added to directory names.
        max_length = max([len(row[0]) for row in rows] or [0])
        output_format = '%%-%ds %%12s %%s %%s' % (max_length, )

        for row in rows:
            print(output_format % row)

    def _list_root(self):
        """Return ls rows for the root: the user's top-level node paths."""
        content = self.cache.get('account-info')
        if content is None:
            resp, content = self.client.request(Application.BASE_URL)
            self.cache.set('account-info', content)
        return [(name, '-', '-', '') for name in content['user_node_paths']]

    def _list_directory(self, remote):
        """Return ls rows (name, size, when_changed, public_url) for *remote*."""
        content = self.cache.get(remote)
        # Entries cached from a parent's listing carry no 'children'.
        if content is None or 'children' not in content:
            url = Application.BASE_URL + remote + '?include_children=true'
            resp, content = self.client.request(url)
            self.cache.set(remote, content)

        if not content.get('has_children', False):
            return []

        rows = []
        for child in content['children']:
            name = child['path'].split('/')[-1]
            if child['kind'] == 'directory':
                name = name + '/'

            public_url = ''
            if child.get('is_public', False):
                public_url = child['public_url']

            rows.append((name, child.get('size', ''),
                         child['when_changed'], public_url))
        return rows

    @split_args
    def do_put(self, values):
        """Put file"""
        if not values:
            raise UserError("Usage: put <local file> [remote name]")
        local = values[0]
        if len(values) > 1:
            remote = values[1]
        else:
            remote = os.path.basename(local)
        path = self.normalize_path(remote)

        node = self.cache.get(path)
        if node is None:
            try:
                resp, node = self.client.request(Application.BASE_URL + path)
            except NotFoundError:
                # Not there yet: create an empty file node, then upload.
                body = dumps({'kind': 'file'})
                headers = {'content-type': 'application/json'}
                resp, node = self.client.request(Application.BASE_URL + path,
                        method="PUT", headers=headers, body=body)
            self.cache.set(path, node)

        url = Application.BASE_CONTENT_URL + self.safe_quote(node['content_path'])

        logger.debug("Uploading %s to %s", local, url)
        # Binary mode: the upload must be byte-exact.
        with open(local, 'rb') as fh:
            # Kept as a bytearray like the original; presumably to stop
            # httplib mixing a str body with unicode header text - TODO
            # confirm.
            data = bytearray(fh.read())

        content_type = mimetypes.guess_type(local)[0]
        if content_type is None:
            content_type = 'application/octet-stream'

        logger.debug("Content size: %d", len(data))

        headers = {'Content-Type': content_type,
                   'Content-Length': str(len(data))}

        resp, content = self.client.request(url,
                method="PUT", body=data, headers=headers)

        if resp['status'][0] != '2':
            raise UserError("Could not upload: %s" % (resp['status'],))

        # Cached node metadata and the parent listing may now be stale.
        self._invalidate(path)

    def download_file(self, remote, local):
        """Download a single remote file *remote* to the local path *local*."""
        path = self.normalize_path(remote)

        # The node metadata provides the content_path and size.
        node = self.cache.get(path)
        if node is None:
            resp, node = self.client.request(Application.BASE_URL + path)
            if resp['status'] != '200':
                raise UserError("Cannot get metadata for %s: %s" % (
                    path, resp['status']))
            self.cache.set(path, node)

        url = Application.BASE_CONTENT_URL + self.safe_quote(node['content_path'])

        logger.debug("Downloading %d bytes to %s", node['size'], local)
        resp, content = self.client.request(url)
        if resp['status'] != '200':
            raise UserError("Cannot download %s: %s" % (
                url, resp['status'],))

        # NOTE(review): OAuthHttpClient.request() decodes application/json
        # responses, so a downloaded JSON file would arrive as a parsed
        # object, not bytes - confirm whether that needs handling.
        # Binary mode so the file is written byte-for-byte.
        with open(local, 'wb') as fh:
            fh.write(content)

    def delete_file(self, remote):
        """Delete the remote node *remote* after checking it exists."""
        path = self.normalize_path(remote)
        url = Application.BASE_URL + path

        # Check that remote exists; cache keys are normalized paths.
        if self.cache.get(path) is None:
            resp, content = self.client.request(url)
            if resp['status'] != '200':
                raise UserError("Cannot get metadata for %s: %s" % (
                    remote, resp['status']))

        resp, content = self.client.request(url, method="DELETE")
        if resp['status'] != '200':
            raise UserError("Could not delete %s: %s" % (
                remote, resp['status'],))

        self._invalidate(path)
        logger.debug("Deleted %s", remote)

    @split_args
    def do_get(self, values):
        """Download file"""
        if not values:
            raise UserError("Usage: get <remote file> [local name]")
        remote = values[0]
        if len(values) > 1:
            local = values[1]
        else:
            local = remote.rstrip('/').split('/')[-1]

        self.download_file(remote, local)
        logger.debug("Downloaded %s to %s", remote, local)

    @split_args
    def do_mget(self, values):
        """Download multiple files"""
        for remote in values:
            self.download_file(remote, remote.rstrip('/').split('/')[-1])
        logger.debug("Finished mget")

    def change_public_status(self, remote, status):
        """Publish (*status* True) or unpublish (False) the file *remote*."""
        path = self.normalize_path(remote)
        url = Application.BASE_URL + path

        resp, content = self.client.request(url)
        public_url = content.get('public_url')
        if bool(public_url) == status:
            description = 'unpublished'
            if public_url:
                description = 'published at %s' % (public_url,)

            print("Nothing to do: %s" % (description,))
            return

        body = dumps({'is_public': status})
        headers = {'content-type': 'application/json'}
        resp, content = self.client.request(url, method="PUT",
                                            headers=headers, body=body)

        if content.get('public_url'):
            print("Published at %s" % (content['public_url'],))
        else:
            print("Unpublished")

        # The parent listing shows public URLs, so drop it too.
        self._invalidate(path)
        self.cache.set(path, content)

    @split_args
    def do_mkvol(self, values):
        """Create new volume"""
        if not values:
            raise UserError("Usage: mkvol <volume path>")
        remote = self.normalize_path(values[0])
        url = Application.BASE_URL + '/volumes' + remote + '/'
        logger.debug("Creating volume at %s", url)
        self.client.request(url, method="PUT")
        # Make the next root listing refetch the account info.
        self.cache.drop('account-info')

    @split_args
    def do_publish(self, values):
        """Publish file"""
        if not values:
            raise UserError("Usage: publish <remote file>")
        self.change_public_status(values[0], True)

    @split_args
    def do_unpublish(self, values):
        """Take down published file"""
        if not values:
            raise UserError("Usage: unpublish <remote file>")
        self.change_public_status(values[0], False)

    def do_debug(self, command):
        """Debug helpers: 'debug cache' (entry count), 'debug cache-clear'"""
        if command == 'cache':
            print("Cache: %d entries" % (len(self.cache.cache),))
        elif command == 'cache-clear':
            self.cache.clear()
            print("Cleared cache")

    def do_EOF(self, values):
        """Exit application (Ctrl+D)"""
        self.finished = True
        # End the prompt line before exiting.
        sys.stdout.write('\n')
        return True

    def emptyline(self):
        """Do nothing (cmd.Cmd's default repeats the last command)."""
        pass

if __name__ == "__main__":
    parser = OptionParser('%prog --oauth')
    parser.add_option('--oauth', dest="oauth",
                      help="OAuth tokens")
    parser.add_option('-f', '--file', dest="file",
                      help="File with commands")
    parser.add_option('-d', '--debug', dest="debug",
                      default=False, action="store_true",
                      help="Enable debug output")
    # For cron scripts mostly. Verbose by default; -q turns output off
    # (store_false needs default=True, otherwise -q would be a no-op).
    parser.add_option('-q', '--quiet', dest="verbose",
                      default=True, action="store_false",
                      help="Be really quiet")
    options, args = parser.parse_args()
    app = Application()

    if options.debug:
        logger.setLevel(logging.DEBUG)
    app.main(options)


