3939from .cursors import Cursor
4040from .constants import CLIENT , COMMAND , FIELD_TYPE , SERVER_STATUS
4141from .util import byte2int , int2byte
42- from .converters import escape_item , encoders , decoders , escape_string
42+ from .converters import (
43+ escape_item , encoders , decoders , escape_string , through )
4344from .err import (
4445 raise_mysql_exception , Warning , Error ,
4546 InterfaceError , DataError , DatabaseError , OperationalError ,
@@ -102,10 +103,6 @@ def _makefile(sock, mode):
102103UNSIGNED_SHORT_COLUMN = 252
103104UNSIGNED_INT24_COLUMN = 253
104105UNSIGNED_INT64_COLUMN = 254
105- UNSIGNED_CHAR_LENGTH = 1
106- UNSIGNED_SHORT_LENGTH = 2
107- UNSIGNED_INT24_LENGTH = 3
108- UNSIGNED_INT64_LENGTH = 8
109106
110107DEFAULT_CHARSET = 'latin1'
111108
@@ -216,18 +213,6 @@ def _hash_password_323(password):
216213def pack_int24 (n ):
217214 return struct .pack ('<I' , n )[:3 ]
218215
219- def unpack_uint16 (n ):
220- return struct .unpack ('<H' , n [0 :2 ])[0 ]
221-
222- def unpack_int24 (n ):
223- return struct .unpack ('<I' , n + b'\0 ' )[0 ]
224-
225- def unpack_int32 (n ):
226- return struct .unpack ('<I' , n )[0 ]
227-
228- def unpack_int64 (n ):
229- return struct .unpack ('<Q' , n )[0 ]
230-
231216
232217class MysqlPacket (object ):
233218 """Representation of a MySQL response packet.
@@ -291,23 +276,54 @@ def get_bytes(self, position, length=1):
291276 """
292277 return self ._data [position :(position + length )]
293278
279+ if PY2 :
280+ def read_uint8 (self ):
281+ result = ord (self ._data [self ._position ])
282+ self ._position += 1
283+ return result
284+ else :
285+ def read_uint8 (self ):
286+ result = self ._data [self ._position ]
287+ self ._position += 1
288+ return result
289+
290+ def read_uint16 (self ):
291+ result = struct .unpack_from ('<H' , self ._data , self ._position )[0 ]
292+ self ._position += 2
293+ return result
294+
295+ def read_uint24 (self ):
296+ low , high = struct .unpack_from ('<HB' , self ._data , self ._position )
297+ self ._position += 3
298+ return low + (high << 16 )
299+
300+ def read_uint32 (self ):
301+ result = struct .unpack_from ('<I' , self ._data , self ._position )[0 ]
302+ self ._position += 4
303+ return result
304+
305+ def read_uint64 (self ):
306+ result = struct .unpack_from ('<Q' , self ._data , self ._position )[0 ]
307+ self ._position += 8
308+ return result
309+
294310 def read_length_encoded_integer (self ):
295311 """Read a 'Length Coded Binary' number from the data buffer.
296312
297313 Length coded numbers can be anywhere from 1 to 9 bytes depending
298314 on the value of the first byte.
299315 """
300- c = ord ( self .read ( 1 ) )
316+ c = self .read_uint8 ( )
301317 if c == NULL_COLUMN :
302318 return None
303319 if c < UNSIGNED_CHAR_COLUMN :
304320 return c
305321 elif c == UNSIGNED_SHORT_COLUMN :
306- return unpack_uint16 ( self .read ( UNSIGNED_SHORT_LENGTH ) )
322+ return self .read_uint16 ( )
307323 elif c == UNSIGNED_INT24_COLUMN :
308- return unpack_int24 ( self .read ( UNSIGNED_INT24_LENGTH ) )
324+ return self .read_uint24 ( )
309325 elif c == UNSIGNED_INT64_COLUMN :
310- return unpack_int64 ( self .read ( UNSIGNED_INT64_LENGTH ) )
326+ return self .read_uint64 ( )
311327
312328 def read_length_coded_string (self ):
313329 """Read a 'Length Coded String' from the data buffer.
@@ -321,6 +337,12 @@ def read_length_coded_string(self):
321337 return None
322338 return self .read (length )
323339
340+ def read_struct (self , fmt ):
341+ s = struct .Struct (fmt )
342+ result = s .unpack_from (self ._data , self ._position )
343+ self ._position += s .size
344+ return result
345+
324346 def is_ok_packet (self ):
325347 return self ._data [0 :1 ] == b'\0 '
326348
@@ -344,7 +366,7 @@ def check_error(self):
344366 if self .is_error_packet ():
345367 self .rewind ()
346368 self .advance (1 ) # field_count == error (we already know that)
347- errno = unpack_uint16 ( self .read ( 2 ) )
369+ errno = self .read_uint16 ( )
348370 if DEBUG : print ("errno =" , errno )
349371 raise_mysql_exception (self ._data )
350372
@@ -374,13 +396,8 @@ def __parse_field_descriptor(self, encoding):
374396 self .org_table = self .read_length_coded_string ().decode (encoding )
375397 self .name = self .read_length_coded_string ().decode (encoding )
376398 self .org_name = self .read_length_coded_string ().decode (encoding )
377- self .advance (1 ) # non-null filler
378- self .charsetnr = struct .unpack ('<H' , self .read (2 ))[0 ]
379- self .length = struct .unpack ('<I' , self .read (4 ))[0 ]
380- self .type_code = byte2int (self .read (1 ))
381- self .flags = struct .unpack ('<H' , self .read (2 ))[0 ]
382- self .scale = byte2int (self .read (1 )) # "decimals"
383- self .advance (2 ) # filler (always 0x00)
399+ self .charsetnr , self .length , self .type_code , self .flags , self .scale = (
400+ self .read_struct ('<xHIBHBxx' ))
384401 # 'default' is a length coded binary and is still in the buffer?
385402 # not used for normal result sets...
386403
@@ -424,8 +441,7 @@ def __init__(self, from_packet):
424441
425442 self .affected_rows = self .packet .read_length_encoded_integer ()
426443 self .insert_id = self .packet .read_length_encoded_integer ()
427- self .server_status = struct .unpack ('<H' , self .packet .read (2 ))[0 ]
428- self .warning_count = struct .unpack ('<H' , self .packet .read (2 ))[0 ]
444+ self .server_status , self .warning_count = self .read_struct ('<HH' )
429445 self .message = self .packet .read_all ()
430446 self .has_next = self .server_status & SERVER_STATUS .SERVER_MORE_RESULTS_EXISTS
431447
@@ -447,9 +463,7 @@ def __init__(self, from_packet):
447463 self .__class__ ))
448464
449465 self .packet = from_packet
450- from_packet .advance (1 )
451- self .warning_count = struct .unpack ('<h' , from_packet .read (2 ))[0 ]
452- self .server_status = struct .unpack ('<h' , self .packet .read (2 ))[0 ]
466+ self .warning_count , self .server_status = self .packet .read_struct ('<xhh' )
453467 if DEBUG : print ("server_status=" , self .server_status )
454468 self .has_next = self .server_status & SERVER_STATUS .SERVER_MORE_RESULTS_EXISTS
455469
@@ -635,7 +649,7 @@ def close(self):
635649 ''' Send the quit message and close the socket '''
636650 if self .socket is None :
637651 raise Error ("Already closed" )
638- send_data = struct .pack ('<i ' , 1 ) + int2byte ( COMMAND .COM_QUIT )
652+ send_data = struct .pack ('<iB ' , 1 , COMMAND .COM_QUIT )
639653 try :
640654 self ._write_bytes (send_data )
641655 except Exception :
@@ -854,14 +868,9 @@ def _read_packet(self, packet_type=MysqlPacket):
854868 while True :
855869 packet_header = self ._read_bytes (4 )
856870 if DEBUG : dump_packet (packet_header )
857- packet_length_bin = packet_header [: 3 ]
858-
871+ btrl , btrh , packet_number = struct . unpack ( '<HBB' , packet_header )
872+ bytes_to_read = btrl + ( btrh << 16 )
859873 #TODO: check sequence id
860- # packet_number
861- byte2int (packet_header [3 ])
862-
863- bin_length = packet_length_bin + b'\0 ' # pad little-endian number
864- bytes_to_read = struct .unpack ('<I' , bin_length )[0 ]
865874 recv_data = self ._read_bytes (bytes_to_read )
866875 if DEBUG : dump_packet (recv_data )
867876 buff += recv_data
@@ -930,7 +939,7 @@ def _execute_command(self, command, sql):
930939
931940 chunk_size = min (MAX_PACKET_LEN , len (sql ) + 1 ) # +1 is for command
932941
933- prelude = struct .pack ('<i ' , chunk_size ) + int2byte ( command )
942+ prelude = struct .pack ('<iB ' , chunk_size , command )
934943 self ._write_bytes (prelude + sql [:chunk_size - 1 ])
935944 if DEBUG : dump_packet (prelude + sql )
936945
@@ -962,8 +971,7 @@ def _request_authentication(self):
962971 if isinstance (self .user , text_type ):
963972 self .user = self .user .encode (self .encoding )
964973
965- data_init = (struct .pack ('<i' , self .client_flag ) + struct .pack ("<I" , 1 ) +
966- int2byte (charset_id ) + int2byte (0 )* 23 )
974+ data_init = struct .pack ('<iIB23s' , self .client_flag , 1 , charset_id , b'' )
967975
968976 next_packet = 1
969977
@@ -1202,23 +1210,12 @@ def _read_rowdata_packet(self):
12021210 self .rows = tuple (rows )
12031211
12041212 def _read_row_from_packet (self , packet ):
1205- use_unicode = self .connection .use_unicode
12061213 row = []
1207- for field in self .fields :
1214+ for encoding , converter in self .converters :
12081215 data = packet .read_length_coded_string ()
12091216 if data is not None :
1210- field_type = field .type_code
1211- if use_unicode :
1212- if field_type in TEXT_TYPES :
1213- charset = charset_by_id (field .charsetnr )
1214- if use_unicode and not charset .is_binary :
1215- # TEXTs with charset=binary means BINARY types.
1216- data = data .decode (charset .encoding )
1217- else :
1218- data = data .decode ()
1219-
1220- converter = self .connection .decoders .get (field_type )
1221- if DEBUG : print ("DEBUG: field={}, converter={}" .format (field , converter ))
1217+ if encoding is not None :
1218+ data = data .decode (encoding )
12221219 if DEBUG : print ("DEBUG: DATA = " , data )
12231220 if converter is not None :
12241221 data = converter (data )
@@ -1228,11 +1225,31 @@ def _read_row_from_packet(self, packet):
12281225 def _get_descriptions (self ):
12291226 """Read a column descriptor packet for each column in the result."""
12301227 self .fields = []
1228+ self .converters = []
1229+ use_unicode = self .connection .use_unicode
12311230 description = []
12321231 for i in range_type (self .field_count ):
12331232 field = self .connection ._read_packet (FieldDescriptorPacket )
12341233 self .fields .append (field )
12351234 description .append (field .description ())
1235+ field_type = field .type_code
1236+ if use_unicode :
1237+ if field_type in TEXT_TYPES :
1238+ charset = charset_by_id (field .charsetnr )
1239+ if charset .is_binary :
1240+ # TEXTs with charset=binary means BINARY types.
1241+ encoding = None
1242+ else :
1243+ encoding = charset .encoding
1244+ else :
1245+ encoding = 'ascii'
1246+ else :
1247+ encoding = None
1248+ converter = self .connection .decoders .get (field_type )
1249+ if converter is through :
1250+ converter = None
1251+ if DEBUG : print ("DEBUG: field={}, converter={}" .format (field , converter ))
1252+ self .converters .append ((encoding , converter ))
12361253
12371254 eof_packet = self .connection ._read_packet ()
12381255 assert eof_packet .is_eof_packet (), 'Protocol error, expecting EOF'
0 commit comments