Skip to content

Commit a724607

Browse files
committed
Merge pull request #299 from YorikSar/conn-speedup
Protocol parsing optimizations
2 parents 6f78fcb + da85c7d commit a724607

2 files changed

Lines changed: 89 additions & 70 deletions

File tree

pymysql/connections.py

Lines changed: 76 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
from .cursors import Cursor
4040
from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
4141
from .util import byte2int, int2byte
42-
from .converters import escape_item, encoders, decoders, escape_string
42+
from .converters import (
43+
escape_item, encoders, decoders, escape_string, through)
4344
from .err import (
4445
raise_mysql_exception, Warning, Error,
4546
InterfaceError, DataError, DatabaseError, OperationalError,
@@ -102,10 +103,6 @@ def _makefile(sock, mode):
102103
UNSIGNED_SHORT_COLUMN = 252
103104
UNSIGNED_INT24_COLUMN = 253
104105
UNSIGNED_INT64_COLUMN = 254
105-
UNSIGNED_CHAR_LENGTH = 1
106-
UNSIGNED_SHORT_LENGTH = 2
107-
UNSIGNED_INT24_LENGTH = 3
108-
UNSIGNED_INT64_LENGTH = 8
109106

110107
DEFAULT_CHARSET = 'latin1'
111108

@@ -216,18 +213,6 @@ def _hash_password_323(password):
216213
def pack_int24(n):
217214
return struct.pack('<I', n)[:3]
218215

219-
def unpack_uint16(n):
220-
return struct.unpack('<H', n[0:2])[0]
221-
222-
def unpack_int24(n):
223-
return struct.unpack('<I', n + b'\0')[0]
224-
225-
def unpack_int32(n):
226-
return struct.unpack('<I', n)[0]
227-
228-
def unpack_int64(n):
229-
return struct.unpack('<Q', n)[0]
230-
231216

232217
class MysqlPacket(object):
233218
"""Representation of a MySQL response packet.
@@ -291,23 +276,54 @@ def get_bytes(self, position, length=1):
291276
"""
292277
return self._data[position:(position+length)]
293278

279+
if PY2:
280+
def read_uint8(self):
281+
result = ord(self._data[self._position])
282+
self._position += 1
283+
return result
284+
else:
285+
def read_uint8(self):
286+
result = self._data[self._position]
287+
self._position += 1
288+
return result
289+
290+
def read_uint16(self):
291+
result = struct.unpack_from('<H', self._data, self._position)[0]
292+
self._position += 2
293+
return result
294+
295+
def read_uint24(self):
296+
low, high = struct.unpack_from('<HB', self._data, self._position)
297+
self._position += 3
298+
return low + (high << 16)
299+
300+
def read_uint32(self):
301+
result = struct.unpack_from('<I', self._data, self._position)[0]
302+
self._position += 4
303+
return result
304+
305+
def read_uint64(self):
306+
result = struct.unpack_from('<Q', self._data, self._position)[0]
307+
self._position += 8
308+
return result
309+
294310
def read_length_encoded_integer(self):
295311
"""Read a 'Length Coded Binary' number from the data buffer.
296312
297313
Length coded numbers can be anywhere from 1 to 9 bytes depending
298314
on the value of the first byte.
299315
"""
300-
c = ord(self.read(1))
316+
c = self.read_uint8()
301317
if c == NULL_COLUMN:
302318
return None
303319
if c < UNSIGNED_CHAR_COLUMN:
304320
return c
305321
elif c == UNSIGNED_SHORT_COLUMN:
306-
return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
322+
return self.read_uint16()
307323
elif c == UNSIGNED_INT24_COLUMN:
308-
return unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
324+
return self.read_uint24()
309325
elif c == UNSIGNED_INT64_COLUMN:
310-
return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
326+
return self.read_uint64()
311327

312328
def read_length_coded_string(self):
313329
"""Read a 'Length Coded String' from the data buffer.
@@ -321,6 +337,12 @@ def read_length_coded_string(self):
321337
return None
322338
return self.read(length)
323339

340+
def read_struct(self, fmt):
341+
s = struct.Struct(fmt)
342+
result = s.unpack_from(self._data, self._position)
343+
self._position += s.size
344+
return result
345+
324346
def is_ok_packet(self):
325347
return self._data[0:1] == b'\0'
326348

@@ -344,7 +366,7 @@ def check_error(self):
344366
if self.is_error_packet():
345367
self.rewind()
346368
self.advance(1) # field_count == error (we already know that)
347-
errno = unpack_uint16(self.read(2))
369+
errno = self.read_uint16()
348370
if DEBUG: print("errno =", errno)
349371
raise_mysql_exception(self._data)
350372

@@ -374,13 +396,8 @@ def __parse_field_descriptor(self, encoding):
374396
self.org_table = self.read_length_coded_string().decode(encoding)
375397
self.name = self.read_length_coded_string().decode(encoding)
376398
self.org_name = self.read_length_coded_string().decode(encoding)
377-
self.advance(1) # non-null filler
378-
self.charsetnr = struct.unpack('<H', self.read(2))[0]
379-
self.length = struct.unpack('<I', self.read(4))[0]
380-
self.type_code = byte2int(self.read(1))
381-
self.flags = struct.unpack('<H', self.read(2))[0]
382-
self.scale = byte2int(self.read(1)) # "decimals"
383-
self.advance(2) # filler (always 0x00)
399+
self.charsetnr, self.length, self.type_code, self.flags, self.scale = (
400+
self.read_struct('<xHIBHBxx'))
384401
# 'default' is a length coded binary and is still in the buffer?
385402
# not used for normal result sets...
386403

@@ -424,8 +441,7 @@ def __init__(self, from_packet):
424441

425442
self.affected_rows = self.packet.read_length_encoded_integer()
426443
self.insert_id = self.packet.read_length_encoded_integer()
427-
self.server_status = struct.unpack('<H', self.packet.read(2))[0]
428-
self.warning_count = struct.unpack('<H', self.packet.read(2))[0]
444+
self.server_status, self.warning_count = self.read_struct('<HH')
429445
self.message = self.packet.read_all()
430446
self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS
431447

@@ -447,9 +463,7 @@ def __init__(self, from_packet):
447463
self.__class__))
448464

449465
self.packet = from_packet
450-
from_packet.advance(1)
451-
self.warning_count = struct.unpack('<h', from_packet.read(2))[0]
452-
self.server_status = struct.unpack('<h', self.packet.read(2))[0]
466+
self.warning_count, self.server_status = self.packet.read_struct('<xhh')
453467
if DEBUG: print("server_status=", self.server_status)
454468
self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS
455469

@@ -635,7 +649,7 @@ def close(self):
635649
''' Send the quit message and close the socket '''
636650
if self.socket is None:
637651
raise Error("Already closed")
638-
send_data = struct.pack('<i', 1) + int2byte(COMMAND.COM_QUIT)
652+
send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)
639653
try:
640654
self._write_bytes(send_data)
641655
except Exception:
@@ -854,14 +868,9 @@ def _read_packet(self, packet_type=MysqlPacket):
854868
while True:
855869
packet_header = self._read_bytes(4)
856870
if DEBUG: dump_packet(packet_header)
857-
packet_length_bin = packet_header[:3]
858-
871+
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
872+
bytes_to_read = btrl + (btrh << 16)
859873
#TODO: check sequence id
860-
# packet_number
861-
byte2int(packet_header[3])
862-
863-
bin_length = packet_length_bin + b'\0' # pad little-endian number
864-
bytes_to_read = struct.unpack('<I', bin_length)[0]
865874
recv_data = self._read_bytes(bytes_to_read)
866875
if DEBUG: dump_packet(recv_data)
867876
buff += recv_data
@@ -930,7 +939,7 @@ def _execute_command(self, command, sql):
930939

931940
chunk_size = min(MAX_PACKET_LEN, len(sql) + 1) # +1 is for command
932941

933-
prelude = struct.pack('<i', chunk_size) + int2byte(command)
942+
prelude = struct.pack('<iB', chunk_size, command)
934943
self._write_bytes(prelude + sql[:chunk_size-1])
935944
if DEBUG: dump_packet(prelude + sql)
936945

@@ -962,8 +971,7 @@ def _request_authentication(self):
962971
if isinstance(self.user, text_type):
963972
self.user = self.user.encode(self.encoding)
964973

965-
data_init = (struct.pack('<i', self.client_flag) + struct.pack("<I", 1) +
966-
int2byte(charset_id) + int2byte(0)*23)
974+
data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id, b'')
967975

968976
next_packet = 1
969977

@@ -1202,23 +1210,12 @@ def _read_rowdata_packet(self):
12021210
self.rows = tuple(rows)
12031211

12041212
def _read_row_from_packet(self, packet):
1205-
use_unicode = self.connection.use_unicode
12061213
row = []
1207-
for field in self.fields:
1214+
for encoding, converter in self.converters:
12081215
data = packet.read_length_coded_string()
12091216
if data is not None:
1210-
field_type = field.type_code
1211-
if use_unicode:
1212-
if field_type in TEXT_TYPES:
1213-
charset = charset_by_id(field.charsetnr)
1214-
if use_unicode and not charset.is_binary:
1215-
# TEXTs with charset=binary means BINARY types.
1216-
data = data.decode(charset.encoding)
1217-
else:
1218-
data = data.decode()
1219-
1220-
converter = self.connection.decoders.get(field_type)
1221-
if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter))
1217+
if encoding is not None:
1218+
data = data.decode(encoding)
12221219
if DEBUG: print("DEBUG: DATA = ", data)
12231220
if converter is not None:
12241221
data = converter(data)
@@ -1228,11 +1225,31 @@ def _read_row_from_packet(self, packet):
12281225
def _get_descriptions(self):
12291226
"""Read a column descriptor packet for each column in the result."""
12301227
self.fields = []
1228+
self.converters = []
1229+
use_unicode = self.connection.use_unicode
12311230
description = []
12321231
for i in range_type(self.field_count):
12331232
field = self.connection._read_packet(FieldDescriptorPacket)
12341233
self.fields.append(field)
12351234
description.append(field.description())
1235+
field_type = field.type_code
1236+
if use_unicode:
1237+
if field_type in TEXT_TYPES:
1238+
charset = charset_by_id(field.charsetnr)
1239+
if charset.is_binary:
1240+
# TEXTs with charset=binary means BINARY types.
1241+
encoding = None
1242+
else:
1243+
encoding = charset.encoding
1244+
else:
1245+
encoding = 'ascii'
1246+
else:
1247+
encoding = None
1248+
converter = self.connection.decoders.get(field_type)
1249+
if converter is through:
1250+
converter = None
1251+
if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter))
1252+
self.converters.append((encoding, converter))
12361253

12371254
eof_packet = self.connection._read_packet()
12381255
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'

pymysql/converters.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,19 @@ def convert_mysql_timestamp(timestamp):
254254
def convert_set(s):
255255
return set(s.split(","))
256256

257-
def convert_bit(b):
258-
#b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
259-
#return struct.unpack(">Q", b)[0]
260-
#
261-
# the snippet above is right, but MySQLdb doesn't process bits,
262-
# so we shouldn't either
263-
return b
257+
258+
def through(x):
259+
return x
260+
261+
262+
#def convert_bit(b):
263+
# b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
264+
# return struct.unpack(">Q", b)[0]
265+
#
266+
# the snippet above is right, but MySQLdb doesn't process bits,
267+
# so we shouldn't either
268+
convert_bit = through
269+
264270

265271
def convert_characters(connection, field, data):
266272
field_charset = charset_by_id(field.charsetnr).name
@@ -297,10 +303,6 @@ def convert_characters(connection, field, data):
297303
Decimal: str,
298304
}
299305

300-
301-
def through(x):
302-
return x
303-
304306
if not PY2 or JYTHON or IRONPYTHON:
305307
encoders[bytes] = escape_bytes
306308

0 commit comments

Comments
 (0)