diff --git a/pygit2/remote.py b/pygit2/remote.py index 8d2416b44..88e5746ca 100644 --- a/pygit2/remote.py +++ b/pygit2/remote.py @@ -70,7 +70,6 @@ def __init__(self, tp): self.received_bytes = tp.received_bytes """"Number of bytes received up to now""" - class Remote(object): def sideband_progress(self, string): @@ -192,6 +191,7 @@ def fetch(self, signature=None, message=None): callbacks.transfer_progress = self._transfer_progress_cb callbacks.update_tips = self._update_tips_cb callbacks.credentials = self._credentials_cb + callbacks.certificate_check = self._certificate_cb # We need to make sure that this handle stays alive self._self_handle = ffi.new_handle(self) callbacks.payload = self._self_handle @@ -304,6 +304,7 @@ def push(self, specs, signature=None, message=None): callbacks.transfer_progress = self._transfer_progress_cb callbacks.update_tips = self._update_tips_cb callbacks.credentials = self._credentials_cb + callbacks.certificate_check = self._certificate_cb callbacks.push_update_reference = self._push_update_reference_cb # We need to make sure that this handle stays alive self._self_handle = ffi.new_handle(self) @@ -414,6 +415,25 @@ def _credentials_cb(cred_out, url, username, allowed, data): return 0 + @ffi.callback('int (*git_transport_certificate_check_cb)' + '(git_cert *cert, int valid, const char *host, void *payload)') + def _certificate_cb(cert_i, valid, host, data): + self = ffi.from_handle(data) + + if not hasattr(self, 'certificate_check') or not self.certificate_check: + return 0 + + try: + val = self.certificate_cb(None, bool(valid), ffi.string(host)) + if not val: + return C.GIT_ECERTIFICATE + except Exception as e: + self._stored_exception = e + return C.GIT_EUSER + + return 0 + + def get_credentials(fn, url, username, allowed): """Call fn and return the credentials object""" diff --git a/test/test_remote.py b/test/test_remote.py index 751cddcbd..de3004d6d 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -242,6 +242,16 @@ def test_fetch(self): self.assertEqual(stats.indexed_objects, REMOTE_REPO_OBJECTS) self.assertEqual(stats.received_objects, REMOTE_REPO_OBJECTS) + def test_fetch_insecure(self): + def no_check(certificate, valid, host): + return True + remote = self.repo.remotes[0] + remote.certificate_check = no_check + stats = remote.fetch() + self.assertEqual(stats.received_bytes, REMOTE_REPO_BYTES) + self.assertEqual(stats.indexed_objects, REMOTE_REPO_OBJECTS) + self.assertEqual(stats.received_objects, REMOTE_REPO_OBJECTS) + def test_transfer_progress(self): self.tp = None def tp_cb(stats):