Merge branch 'master' into docs

This commit is contained in:
Dan 2018-01-18 13:23:20 +01:00
commit 84dccf327b
10 changed files with 76 additions and 58 deletions

View File

@ -53,6 +53,7 @@ from pyrogram.session import Auth, Session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
Config = namedtuple("Config", ["api_id", "api_hash"]) Config = namedtuple("Config", ["api_id", "api_hash"])
Proxy = namedtuple("Proxy", ["enabled", "hostname", "port", "username", "password"])
class Client: class Client:
@ -90,6 +91,7 @@ class Client:
self.markdown = Markdown(self.peers_by_id) self.markdown = Markdown(self.peers_by_id)
self.config = None self.config = None
self.proxy = None
self.session = None self.session = None
self.update_handler = None self.update_handler = None
@ -101,7 +103,7 @@ class Client:
self.load_config() self.load_config()
self.load_session(self.session_name) self.load_session(self.session_name)
self.session = Session(self.dc_id, self.test_mode, self.auth_key, self.config.api_id) self.session = Session(self.dc_id, self.test_mode, self.proxy, self.auth_key, self.config.api_id)
terms = self.session.start() terms = self.session.start()
@ -191,9 +193,9 @@ class Client:
self.session.stop() self.session.stop()
self.dc_id = e.x self.dc_id = e.x
self.auth_key = Auth(self.dc_id, self.test_mode).create() self.auth_key = Auth(self.dc_id, self.test_mode, self.proxy).create()
self.session = Session(self.dc_id, self.test_mode, self.auth_key, self.config.api_id) self.session = Session(self.dc_id, self.test_mode, self.proxy, self.auth_key, self.config.api_id)
self.session.start() self.session.start()
r = self.send( r = self.send(
@ -290,12 +292,21 @@ class Client:
return r.user.id return r.user.id
def load_config(self): def load_config(self):
config = ConfigParser() parser = ConfigParser()
config.read("config.ini") parser.read("config.ini")
self.config = Config( self.config = Config(
int(config["pyrogram"]["api_id"]), api_id=parser.getint("pyrogram", "api_id"),
config["pyrogram"]["api_hash"] api_hash=parser.get("pyrogram", "api_hash")
)
if parser.has_section("proxy"):
self.proxy = Proxy(
enabled=parser.getboolean("proxy", "enabled"),
hostname=parser.get("proxy", "hostname"),
port=parser.getint("proxy", "port"),
username=parser.get("proxy", "username", fallback=None) or None,
password=parser.get("proxy", "password", fallback=None) or None
) )
def load_session(self, session_name): def load_session(self, session_name):
@ -304,7 +315,7 @@ class Client:
s = json.load(f) s = json.load(f)
except FileNotFoundError: except FileNotFoundError:
self.dc_id = 1 self.dc_id = 1
self.auth_key = Auth(self.dc_id, self.test_mode).create() self.auth_key = Auth(self.dc_id, self.test_mode, self.proxy).create()
else: else:
self.dc_id = s["dc_id"] self.dc_id = s["dc_id"]
self.test_mode = s["test_mode"] self.test_mode = s["test_mode"]
@ -1297,7 +1308,7 @@ class Client:
file_id = file_id or self.rnd_id() file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(self.dc_id, self.test_mode, self.auth_key, self.config.api_id) session = Session(self.dc_id, self.test_mode, self.proxy, self.auth_key, self.config.api_id)
session.start() session.start()
try: try:
@ -1362,7 +1373,8 @@ class Client:
session = Session( session = Session(
dc_id, dc_id,
self.test_mode, self.test_mode,
Auth(dc_id, self.test_mode).create(), self.proxy,
Auth(dc_id, self.test_mode, self.proxy).create(),
self.config.api_id self.config.api_id
) )
@ -1378,6 +1390,7 @@ class Client:
session = Session( session = Session(
dc_id, dc_id,
self.test_mode, self.test_mode,
self.proxy,
self.auth_key, self.auth_key,
self.config.api_id self.config.api_id
) )
@ -1433,7 +1446,8 @@ class Client:
cdn_session = Session( cdn_session = Session(
r.dc_id, r.dc_id,
self.test_mode, self.test_mode,
Auth(r.dc_id, self.test_mode).create(), self.proxy,
Auth(r.dc_id, self.test_mode, self.proxy).create(),
self.config.api_id, self.config.api_id,
is_cdn=True is_cdn=True
) )

View File

@ -32,15 +32,16 @@ class Connection:
2: TCPIntermediate 2: TCPIntermediate
} }
def __init__(self, ipv4: str, mode: int = 1): def __init__(self, ipv4: str, proxy: type, mode: int = 1):
self.address = (ipv4, 80) self.address = (ipv4, 80)
self.proxy = proxy
self.mode = self.MODES.get(mode, TCPAbridged) self.mode = self.MODES.get(mode, TCPAbridged)
self.lock = threading.Lock() self.lock = threading.Lock()
self.connection = None self.connection = None
def connect(self): def connect(self):
while True: while True:
self.connection = self.mode() self.connection = self.mode(self.proxy)
try: try:
log.info("Connecting...") log.info("Connecting...")
@ -56,7 +57,7 @@ class Connection:
def send(self, data: bytes): def send(self, data: bytes):
with self.lock: with self.lock:
self.connection.send(data) self.connection.sendall(data)
def recv(self) -> bytes or None: def recv(self) -> bytes or None:
return self.connection.recv() return self.connection.recvall()

View File

@ -18,19 +18,30 @@
import logging import logging
import socket import socket
from collections import namedtuple
import socks
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
Proxy = namedtuple("Proxy", ["enabled", "hostname", "port", "username", "password"])
class TCP(socket.socket):
def __init__(self): class TCP(socks.socksocket):
def __init__(self, proxy: Proxy):
super().__init__() super().__init__()
self.proxy_enabled = False
def send(self, *args): if proxy and proxy.enabled:
pass self.proxy_enabled = True
def recv(self, *args): self.set_proxy(
pass proxy_type=socks.SOCKS5,
addr=proxy.hostname,
port=proxy.port,
username=proxy.username,
password=proxy.password
)
def close(self): def close(self):
try: try:

View File

@ -24,16 +24,16 @@ log = logging.getLogger(__name__)
class TCPAbridged(TCP): class TCPAbridged(TCP):
def __init__(self): def __init__(self, proxy: type):
super().__init__() super().__init__(proxy)
self.is_first_packet = None self.is_first_packet = None
def connect(self, address: tuple): def connect(self, address: tuple):
super().connect(address) super().connect(address)
self.is_first_packet = True self.is_first_packet = True
log.info("Connected!") log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def send(self, data: bytes): def sendall(self, data: bytes, *args):
length = len(data) // 4 length = len(data) // 4
data = ( data = (
@ -48,20 +48,16 @@ class TCPAbridged(TCP):
super().sendall(data) super().sendall(data)
def recv(self) -> bytes or None: def recvall(self, length: int = 0) -> bytes or None:
length = self.recvall(1) length = super().recvall(1)
if length is None: if length is None:
return None return None
if length == b"\x7f": if length == b"\x7f":
length = self.recvall(3) length = super().recvall(3)
if length is None: if length is None:
return None return None
length = int.from_bytes(length, "little") * 4 return super().recvall(int.from_bytes(length, "little") * 4)
packet = self.recvall(length)
return packet

View File

@ -26,16 +26,16 @@ log = logging.getLogger(__name__)
class TCPFull(TCP): class TCPFull(TCP):
def __init__(self): def __init__(self, proxy: type):
super().__init__() super().__init__(proxy)
self.seq_no = None self.seq_no = None
def connect(self, address: tuple): def connect(self, address: tuple):
super().connect(address) super().connect(address)
self.seq_no = 0 self.seq_no = 0
log.info("Connected!") log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def send(self, data: bytes): def sendall(self, data: bytes, *args):
# 12 = packet_length (4), seq_no (4), crc32 (4) (at the end) # 12 = packet_length (4), seq_no (4), crc32 (4) (at the end)
data = pack("<II", len(data) + 12, self.seq_no) + data data = pack("<II", len(data) + 12, self.seq_no) + data
data += pack("<I", crc32(data)) data += pack("<I", crc32(data))
@ -43,13 +43,13 @@ class TCPFull(TCP):
super().sendall(data) super().sendall(data)
def recv(self) -> bytes or None: def recvall(self, length: int = 0) -> bytes or None:
length = self.recvall(4) length = super().recvall(4)
if length is None: if length is None:
return None return None
packet = self.recvall(unpack("<I", length)[0] - 4) packet = super().recvall(unpack("<I", length)[0] - 4)
if packet is None: if packet is None:
return None return None

View File

@ -25,16 +25,16 @@ log = logging.getLogger(__name__)
class TCPIntermediate(TCP): class TCPIntermediate(TCP):
def __init__(self): def __init__(self, proxy: type):
super().__init__() super().__init__(proxy)
self.is_first_packet = None self.is_first_packet = None
def connect(self, address: tuple): def connect(self, address: tuple):
super().connect(address) super().connect(address)
self.is_first_packet = True self.is_first_packet = True
log.info("Connected!") log.info("Connected{}!".format(" with proxy" if self.proxy_enabled else ""))
def send(self, data: bytes): def sendall(self, data: bytes, *args):
length = len(data) length = len(data)
data = pack("<i", length) + data data = pack("<i", length) + data
@ -44,12 +44,10 @@ class TCPIntermediate(TCP):
super().sendall(data) super().sendall(data)
def recv(self) -> bytes or None: def recvall(self, length: int = 0) -> bytes or None:
length = self.recvall(4) length = super().recvall(4)
if length is None: if length is None:
return None return None
packet = self.recvall(unpack("<I", length)[0]) return super().recvall(unpack("<I", length)[0])
return packet

View File

@ -16,10 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
try: from pyaes import AES
from pyaes import AES
except ImportError:
pass
BLOCK_SIZE = 16 BLOCK_SIZE = 16

View File

@ -46,11 +46,11 @@ class Auth:
16 16
) )
def __init__(self, dc_id: int, test_mode: bool): def __init__(self, dc_id: int, test_mode: bool, proxy: type):
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.connection = Connection(DataCenter(dc_id, test_mode)) self.connection = Connection(DataCenter(dc_id, test_mode), proxy)
self.msg_id = MsgId() self.msg_id = MsgId()
def pack(self, data: Object) -> bytes: def pack(self, data: Object) -> bytes:

View File

@ -68,7 +68,7 @@ class Session:
notice_displayed = False notice_displayed = False
def __init__(self, dc_id: int, test_mode: bool, auth_key: bytes, api_id: str, is_cdn: bool = False): def __init__(self, dc_id: int, test_mode: bool, proxy: type, auth_key: bytes, api_id: str, is_cdn: bool = False):
if not Session.notice_displayed: if not Session.notice_displayed:
print("Pyrogram v{}, {}".format(__version__, __copyright__)) print("Pyrogram v{}, {}".format(__version__, __copyright__))
print("Licensed under the terms of the " + __license__, end="\n\n") print("Licensed under the terms of the " + __license__, end="\n\n")
@ -76,7 +76,7 @@ class Session:
self.is_cdn = is_cdn self.is_cdn = is_cdn
self.connection = Connection(DataCenter(dc_id, test_mode)) self.connection = Connection(DataCenter(dc_id, test_mode), proxy)
self.api_id = api_id self.api_id = api_id

View File

@ -26,5 +26,6 @@ classifiers =
[options] [options]
packages = find: packages = find:
zip_safe = False zip_safe = False
install_requires = pyaes setup_requires = pyaes; pysocks
install_requires = pyaes; pysocks
include_package_data = True include_package_data = True