Lock the send method for every tcp mode, not only for tcp_full

This commit is contained in:
Dan 2017-12-18 14:14:44 +01:00
parent dde01cc9b9
commit b23b41bc7d
2 changed files with 9 additions and 10 deletions

View File

@ -17,6 +17,7 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
import threading
import time import time
from .transport import * from .transport import *
@ -34,6 +35,7 @@ class Connection:
def __init__(self, ipv4: str, mode: int = 1): def __init__(self, ipv4: str, mode: int = 1):
self.address = (ipv4, 80) self.address = (ipv4, 80)
self.mode = self.MODES.get(mode, TCPAbridged) self.mode = self.MODES.get(mode, TCPAbridged)
self.lock = threading.Lock()
self.connection = None self.connection = None
def connect(self): def connect(self):
@ -53,6 +55,7 @@ class Connection:
self.connection.close() self.connection.close()
def send(self, data: bytes): def send(self, data: bytes):
with self.lock:
self.connection.send(data) self.connection.send(data)
def recv(self) -> bytes or None: def recv(self) -> bytes or None:

View File

@ -19,7 +19,6 @@
import logging import logging
from binascii import crc32 from binascii import crc32
from struct import pack, unpack from struct import pack, unpack
from threading import Lock
from .tcp import TCP from .tcp import TCP
@ -30,7 +29,6 @@ class TCPFull(TCP):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.seq_no = None self.seq_no = None
self.lock = Lock()
def connect(self, address: tuple): def connect(self, address: tuple):
super().connect(address) super().connect(address)
@ -38,11 +36,9 @@ class TCPFull(TCP):
log.info("Connected!") log.info("Connected!")
def send(self, data: bytes): def send(self, data: bytes):
with self.lock:
# 12 = packet_length (4), seq_no (4), crc32 (4) (at the end) # 12 = packet_length (4), seq_no (4), crc32 (4) (at the end)
data = pack("<II", len(data) + 12, self.seq_no) + data data = pack("<II", len(data) + 12, self.seq_no) + data
data += pack("<I", crc32(data)) data += pack("<I", crc32(data))
self.seq_no += 1 self.seq_no += 1
super().sendall(data) super().sendall(data)