# Copyright (C) 2005 James Henstridge

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

"""Monkey patch OpenSSH support into the Bazaar sftp transport."""

# Python 2: make every class in this module that is defined without explicit
# base classes (e.g. SFTPSubprocess) a new-style class.
__metaclass__ = type

import sys
import os
import subprocess
import urlparse

from bzrlib.transport.sftp import SFTPTransport

from paramiko.sftp_client import SFTPClient


# Register sftp as a scheme whose URLs carry a network location.  Older
# urlparse versions only split "user@host:port" out of the URL for schemes
# listed in uses_netloc; _parse_url below relies on getting a netloc back.
if 'sftp' not in urlparse.uses_netloc:
    urlparse.uses_netloc.append('sftp')

class SFTPSubprocess:
    """A socket-like object that talks to an ssh subprocess via pipes.

    Spawns ``ssh ... -s -- <host> sftp`` (ssh's sftp subsystem).  Data given
    to ``send`` is written to the child's stdin, and ``recv`` reads from the
    child's stdout.  stderr is not redirected, so ssh's diagnostics go to
    this process's stderr.  In this module an instance is handed to
    paramiko's ``SFTPClient`` in place of a socket/channel.
    """

    def __init__(self, hostname, port=None, user=None):
        """Start the ssh process.

        :param hostname: host to connect to.
        :param port: optional port number (passed to ``ssh -p``).
        :param user: optional remote user name (passed to ``ssh -l``).
        """
        self.proc = subprocess.Popen(self._ssh_args(hostname, port, user),
                                     close_fds=True,
                                     stdin=subprocess.PIPE,
                                     stdout=subprocess.PIPE)

    @staticmethod
    def _ssh_args(hostname, port=None, user=None):
        """Return the argv list used to run the ssh sftp subsystem."""
        args = ['ssh',
                '-oForwardX11=no', '-oForwardAgent=no',
                '-oClearAllForwardings=yes', '-oProtocol=2',
                '-oNoHostAuthenticationForLocalhost=yes']
        if port is not None:
            args.extend(['-p', str(port)])
        if user is not None:
            args.extend(['-l', user])
        # The host name comes from a URL.  '--' ends ssh's option parsing, so
        # a host such as "-oProxyCommand=..." can never be taken as an option.
        args.extend(['-s', '--', hostname, 'sftp'])
        return args

    def send(self, data):
        """Write *data* to ssh's stdin.

        Returns the number of bytes written, which (as with os.write) may be
        less than len(data).
        """
        return os.write(self.proc.stdin.fileno(), data)

    def recv(self, count):
        """Read up to *count* bytes from ssh's stdout.

        Returns an empty string once ssh has closed its output.
        """
        return os.read(self.proc.stdout.fileno(), count)

    def close(self):
        """Close both pipes, then block until the ssh process exits."""
        self.proc.stdin.close()
        self.proc.stdout.close()
        self.proc.wait()


def _unparse_url(self, path=None):
    """Build an sftp:// URL for this transport's connection.

    :param path: path component to use; defaults to ``self._path``.
    :return: a URL of the form ``sftp://[user@]host[:port]path``.
    """
    if path is None:
        path = self._path
    host = self._host
    # A bare IPv6 address must be bracketed.  Otherwise its colons would be
    # read as the port separator when the URL is parsed again.
    if ':' in host and not host.startswith('['):
        host = '[%s]' % host
    netloc = host
    if self._username is not None:
        netloc = '%s@%s' % (self._username, netloc)
    if self._port is not None:
        netloc = '%s:%d' % (netloc, self._port)

    return urlparse.urlunparse(('sftp', netloc, path, '', '', ''))

def _parse_url(self, url):
    """Parse an sftp:// URL and store its parts on the transport.

    Sets ``self._username`` (None when the URL names no user),
    ``self._host``, ``self._port`` (an int, or None when no port is given)
    and ``self._path`` (``'/'`` when the URL has no path).

    IPv6 address literals must be written in brackets, for example
    ``sftp://[::1]:2222/srv``.  The brackets are removed from ``self._host``
    so the bare address can be handed to ssh.

    :raises ValueError: if the port is not a decimal integer.
    """
    (scheme, netloc, path, params,
     query, fragment) = urlparse.urlparse(url, allow_fragments=False)
    assert scheme == 'sftp'

    # A host name can never contain '@' but a user name might, so split on
    # the last '@' rather than the first.
    if '@' in netloc:
        username, hostport = netloc.rsplit('@', 1)
    else:
        username, hostport = None, netloc
    # "sftp://@host/" names no user; don't end up running "ssh -l ''".
    self._username = username or None

    if hostport.startswith('[') and ']' in hostport:
        # Bracketed IPv6 literal: "[addr]" or "[addr]:port".  The address
        # itself contains colons, so only look for a port after the ']'.
        host, rest = hostport[1:].split(']', 1)
        if rest.startswith(':'):
            rest = rest[1:]
        port = rest
    elif ':' in hostport:
        host, port = hostport.rsplit(':', 1)
    else:
        host, port = hostport, ''
    self._host = host
    # An empty port ("host:") means no port was given.
    if port:
        self._port = int(port)
    else:
        self._port = None

    self._path = path
    if self._path == '':
        self._path = '/'

def _sftp_connect(self):
    """Open an SFTP session over an ``ssh`` subprocess.

    Sets ``self._sftp`` to a paramiko ``SFTPClient`` that talks through an
    SFTPSubprocess for ``self._host``/``self._port``/``self._username``.
    Presumably invoked by SFTPTransport when it needs a connection; confirm
    against bzrlib.
    """
    sock = SFTPSubprocess(self._host, self._port, self._username)
    try:
        self._sftp = SFTPClient(sock)
    except:
        # Creating the client can fail (e.g. ssh exits before the SFTP
        # handshake completes).  Close the pipes and reap the ssh child
        # rather than leaking them, then propagate the original error.
        sock.close()
        raise

# Install the OpenSSH-based URL handling and connection code on the stock
# SFTPTransport class, replacing its own implementations of these methods.
for _attr_name, _replacement in (('_unparse_url', _unparse_url),
                                 ('_parse_url', _parse_url),
                                 ('_sftp_connect', _sftp_connect)):
    setattr(SFTPTransport, _attr_name, _replacement)
del _attr_name, _replacement
