Skip to content

Commit dc98ad1

Browse files
committed
tried to implememnt handshake and related unit tests
1 parent d03482c commit dc98ad1

12 files changed

Lines changed: 1178 additions & 14 deletions

File tree

.github/workflows/build.yml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: build
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
8+
jobs:
9+
build:
10+
runs-on: ubuntu-latest
11+
steps:
12+
- uses: actions/checkout@v1
13+
- name: Setup .NET Core
14+
uses: actions/setup-dotnet@v3
15+
with:
16+
dotnet-version: '9.0.x'
17+
- name: Set env
18+
run: |
19+
echo "DOTNET_CLI_TELEMETRY_OPTOUT=1" >> $GITHUB_ENV
20+
echo "DOTNET_hostBuilder:reloadConfigOnChange=false" >> $GITHUB_ENV
21+
- uses: dotnet/nbgv@master
22+
id: nbgv
23+
- name: Clean
24+
run: |
25+
dotnet clean ./SuperSocket.MySQL.sln --configuration Release
26+
dotnet nuget locals all --clear
27+
- name: Build
28+
run: dotnet build -c Debug
29+
- name: Run MySQL
30+
run: |
31+
cp tests/SuperSocket.MySQL.Test/mysql.cnf ~/.my.cnf
32+
sudo systemctl start mysql.service
33+
mysql -V
34+
- name: Test
35+
run: |
36+
cd tests/SuperSocket.MySQL.Test
37+
dotnet test
38+
- name: Pack
39+
run: dotnet pack -c Release -p:PackageVersion=${{ steps.nbgv.outputs.NuGetPackageVersion }}.${{ github.run_number }} -p:Version=${{ steps.nbgv.outputs.NuGetPackageVersion }}.${{ github.run_number }} -p:AssemblyVersion=${{ steps.nbgv.outputs.AssemblyVersion }} -p:AssemblyFileVersion=${{ steps.nbgv.outputs.AssemblyFileVersion }} -p:AssemblyInformationalVersion=${{ steps.nbgv.outputs.AssemblyInformationalVersion }} /p:NoPackageAnalysis=true /p:IncludeReleaseNotes=false
40+
- name: Push
41+
run: dotnet nuget push **/*.nupkg --api-key ${{ secrets.MYGET_API_KEY }} --source https://www.myget.org/F/supersocket/api/v3/index.json

.travis.yml

Lines changed: 0 additions & 9 deletions
This file was deleted.

src/SuperSocket.MySQL/MySQLConnection.cs

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System;
22
using System.Net;
33
using System.Net.Sockets;
4+
using System.Security.Cryptography;
5+
using System.Text;
46
using System.Threading;
57
using System.Threading.Tasks;
68
using SuperSocket.Client;
@@ -19,8 +21,14 @@ public class MySQLConnection : EasyClient<MySQLPacket>
1921

2022
private static readonly MySQLPacketEncoder PacketEncoder = new MySQLPacketEncoder();
2123

24+
public bool IsAuthenticated { get; private set; }
25+
2226
public MySQLConnection(string host, int port, string userName, string password)
23-
: this(new MySQLPacketFactory().RegisterPacketType<HandshakeResponsePacket>(0x00))
27+
: this(new MySQLPacketFactory()
28+
.RegisterPacketType<HandshakePacket>(0x0A)
29+
.RegisterPacketType<HandshakeResponsePacket>(0x00)
30+
.RegisterPacketType<OKPacket>(0x00)
31+
.RegisterPacketType<ErrorPacket>(0xFF))
2432
{
2533
_host = host ?? throw new ArgumentNullException(nameof(host));
2634
_port = port > 0 ? port : DefaultPort;
@@ -50,11 +58,127 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
5058

5159
await ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false);
5260

53-
// Send initial handshake packet
54-
var handshakePacket = new HandshakePacket();
55-
await SendAsync(PacketEncoder, handshakePacket).ConfigureAwait(false);
61+
// Wait for server's handshake packet
62+
var packet = await ReceiveAsync().ConfigureAwait(false);
63+
if (!(packet is HandshakePacket handshakePacket))
64+
throw new InvalidOperationException("Expected handshake packet from server.");
65+
66+
// Prepare handshake response
67+
var handshakeResponse = new HandshakeResponsePacket
68+
{
69+
CapabilityFlags = (uint)(ClientCapabilities.CLIENT_PROTOCOL_41 |
70+
ClientCapabilities.CLIENT_SECURE_CONNECTION |
71+
ClientCapabilities.CLIENT_PLUGIN_AUTH |
72+
ClientCapabilities.CLIENT_CONNECT_WITH_DB),
73+
MaxPacketSize = 16777216, // 16MB
74+
CharacterSet = 0x21, // utf8_general_ci
75+
Username = _userName,
76+
Database = string.Empty, // Can be set later if needed
77+
AuthPluginName = "mysql_native_password"
78+
};
79+
80+
// Generate authentication response
81+
handshakeResponse.AuthResponse = GenerateAuthResponse(handshakePacket);
82+
83+
// Send handshake response
84+
await SendAsync(PacketEncoder, handshakeResponse).ConfigureAwait(false);
85+
86+
// Wait for authentication result (OK packet or Error packet)
87+
var authResult = await ReceiveAsync().ConfigureAwait(false);
88+
89+
switch (authResult)
90+
{
91+
case OKPacket okPacket:
92+
// Authentication successful
93+
IsAuthenticated = true;
94+
break;
95+
case ErrorPacket errorPacket:
96+
// Authentication failed
97+
var errorMsg = !string.IsNullOrEmpty(errorPacket.ErrorMessage)
98+
? errorPacket.ErrorMessage
99+
: "Authentication failed";
100+
throw new InvalidOperationException($"MySQL authentication failed: {errorMsg} (Error {errorPacket.ErrorCode})");
101+
default:
102+
throw new InvalidOperationException($"Unexpected packet received during authentication: {authResult?.GetType().Name ?? "null"}");
103+
}
104+
}
105+
106+
private byte[] GenerateAuthResponse(HandshakePacket handshakePacket)
107+
{
108+
if (string.IsNullOrEmpty(_password))
109+
return Array.Empty<byte>();
110+
111+
// Combine auth plugin data parts to form the complete salt
112+
var salt = new byte[20];
113+
handshakePacket.AuthPluginDataPart1?.CopyTo(salt, 0);
114+
if (handshakePacket.AuthPluginDataPart2 != null)
115+
{
116+
var part2Length = Math.Min(handshakePacket.AuthPluginDataPart2.Length, 12);
117+
Array.Copy(handshakePacket.AuthPluginDataPart2, 0, salt, 8, part2Length);
118+
}
119+
120+
// MySQL native password authentication algorithm:
121+
// SHA1(password) XOR SHA1(salt + SHA1(SHA1(password)))
122+
using (var sha1 = SHA1.Create())
123+
{
124+
var passwordBytes = Encoding.UTF8.GetBytes(_password);
125+
var sha1Password = sha1.ComputeHash(passwordBytes);
126+
var sha1Sha1Password = sha1.ComputeHash(sha1Password);
127+
128+
var combined = new byte[salt.Length + sha1Sha1Password.Length];
129+
salt.CopyTo(combined, 0);
130+
sha1Sha1Password.CopyTo(combined, salt.Length);
56131

57-
// Handle authentication to be implemented here
132+
var sha1Combined = sha1.ComputeHash(combined);
133+
134+
var result = new byte[sha1Password.Length];
135+
for (int i = 0; i < sha1Password.Length; i++)
136+
{
137+
result[i] = (byte)(sha1Password[i] ^ sha1Combined[i]);
138+
}
139+
140+
return result;
141+
}
142+
}
143+
144+
/// <summary>
145+
/// Executes a simple query (placeholder implementation)
146+
/// </summary>
147+
/// <param name="query">The SQL query to execute</param>
148+
/// <param name="cancellationToken">Cancellation token</param>
149+
/// <returns>Task representing the async operation</returns>
150+
public async Task<string> ExecuteQueryAsync(string query, CancellationToken cancellationToken = default)
151+
{
152+
if (!IsAuthenticated)
153+
throw new InvalidOperationException("Connection is not authenticated. Call ConnectAsync first.");
154+
155+
if (string.IsNullOrEmpty(query))
156+
throw new ArgumentException("Query cannot be null or empty.", nameof(query));
157+
158+
// This is a placeholder implementation
159+
// In a complete implementation, you would:
160+
// 1. Create a COM_QUERY packet with the SQL query
161+
// 2. Send the packet to the server
162+
// 3. Receive and parse the result set
163+
// 4. Return the results
164+
165+
await Task.Delay(10, cancellationToken); // Simulate async operation
166+
return "Query execution not fully implemented yet";
167+
}
168+
169+
/// <summary>
170+
/// Disconnects from the MySQL server and resets authentication state
171+
/// </summary>
172+
public async Task DisconnectAsync()
173+
{
174+
try
175+
{
176+
await CloseAsync();
177+
}
178+
finally
179+
{
180+
IsAuthenticated = false;
181+
}
58182
}
59183
}
60184
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
using System.Buffers;
2+
using System.Text;
3+
4+
namespace SuperSocket.MySQL.Packets
5+
{
6+
public class ErrorPacket : MySQLPacket
7+
{
8+
public byte Header { get; set; } = 0xFF; // Error packet identifier
9+
public ushort ErrorCode { get; set; }
10+
public string SqlStateMarker { get; set; } = "#";
11+
public string SqlState { get; set; }
12+
public string ErrorMessage { get; set; }
13+
14+
protected internal override void Decode(ref SequenceReader<byte> reader, object context)
15+
{
16+
// Read header (should be 0xFF for Error packet)
17+
reader.TryRead(out byte header);
18+
Header = header;
19+
20+
// Read error code (2 bytes)
21+
reader.TryReadLittleEndian(out short errorCode);
22+
ErrorCode = (ushort)errorCode;
23+
24+
// Check for SQL state marker and state (optional, depends on capability flags)
25+
if (reader.Remaining >= 6 && reader.UnreadSequence.FirstSpan[0] == (byte)'#')
26+
{
27+
// Read SQL state marker
28+
reader.TryRead(out byte marker);
29+
SqlStateMarker = ((char)marker).ToString();
30+
31+
// Read SQL state (5 characters)
32+
var sqlStateBytes = new byte[5];
33+
reader.TryCopyTo(sqlStateBytes);
34+
reader.Advance(5);
35+
SqlState = Encoding.UTF8.GetString(sqlStateBytes);
36+
}
37+
38+
// Read error message (rest of the packet)
39+
if (reader.Remaining > 0)
40+
{
41+
var messageBytes = new byte[reader.Remaining];
42+
reader.TryCopyTo(messageBytes);
43+
reader.Advance((int)reader.Remaining);
44+
ErrorMessage = Encoding.UTF8.GetString(messageBytes);
45+
}
46+
}
47+
48+
protected internal override int Encode(IBufferWriter<byte> writer)
49+
{
50+
var bytesWritten = 0;
51+
52+
// Write header
53+
bytesWritten += writer.WriteUInt8(Header);
54+
55+
// Write error code
56+
bytesWritten += writer.WriteUInt16(ErrorCode);
57+
58+
// Write SQL state marker and state if present
59+
if (!string.IsNullOrEmpty(SqlState))
60+
{
61+
var markerBytes = Encoding.UTF8.GetBytes(SqlStateMarker ?? "#");
62+
var span = writer.GetSpan(markerBytes.Length);
63+
for (int i = 0; i < markerBytes.Length; i++)
64+
span[i] = markerBytes[i];
65+
writer.Advance(markerBytes.Length);
66+
bytesWritten += markerBytes.Length;
67+
68+
var sqlStateBytes = Encoding.UTF8.GetBytes(SqlState.PadRight(5).Substring(0, 5));
69+
span = writer.GetSpan(5);
70+
for (int i = 0; i < 5; i++)
71+
span[i] = sqlStateBytes[i];
72+
writer.Advance(5);
73+
bytesWritten += 5;
74+
}
75+
76+
// Write error message
77+
if (!string.IsNullOrEmpty(ErrorMessage))
78+
{
79+
var messageBytes = Encoding.UTF8.GetBytes(ErrorMessage);
80+
var span = writer.GetSpan(messageBytes.Length);
81+
for (int i = 0; i < messageBytes.Length; i++)
82+
span[i] = messageBytes[i];
83+
writer.Advance(messageBytes.Length);
84+
bytesWritten += messageBytes.Length;
85+
}
86+
87+
return bytesWritten;
88+
}
89+
}
90+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using System.Buffers;
2+
3+
namespace SuperSocket.MySQL.Packets
4+
{
5+
public class OKPacket : MySQLPacket
6+
{
7+
public byte Header { get; set; } = 0x00; // OK packet identifier
8+
public ulong AffectedRows { get; set; }
9+
public ulong LastInsertId { get; set; }
10+
public ushort StatusFlags { get; set; }
11+
public ushort Warnings { get; set; }
12+
public string Info { get; set; }
13+
14+
protected internal override void Decode(ref SequenceReader<byte> reader, object context)
15+
{
16+
// Read header (should be 0x00 for OK packet)
17+
reader.TryRead(out byte header);
18+
Header = header;
19+
20+
// Read affected rows (length-encoded integer)
21+
AffectedRows = reader.TryReadLengthEncodedInteger(out long affectedRows) ? (ulong)affectedRows : 0;
22+
23+
// Read last insert ID (length-encoded integer)
24+
LastInsertId = reader.TryReadLengthEncodedInteger(out long lastInsertId) ? (ulong)lastInsertId : 0;
25+
26+
// Read status flags (2 bytes)
27+
reader.TryReadLittleEndian(out short statusFlags);
28+
StatusFlags = (ushort)statusFlags;
29+
30+
// Read warnings (2 bytes)
31+
reader.TryReadLittleEndian(out short warnings);
32+
Warnings = (ushort)warnings;
33+
34+
// Read info string if remaining data
35+
if (reader.Remaining > 0)
36+
{
37+
Info = reader.TryReadLengthEncodedString(out string info) ? info : string.Empty;
38+
}
39+
}
40+
41+
protected internal override int Encode(IBufferWriter<byte> writer)
42+
{
43+
var bytesWritten = 0;
44+
45+
// Write header
46+
bytesWritten += writer.WriteUInt8(Header);
47+
48+
// Write affected rows
49+
bytesWritten += writer.WriteUInt64(AffectedRows);
50+
51+
// Write last insert ID
52+
bytesWritten += writer.WriteUInt64(LastInsertId);
53+
54+
// Write status flags
55+
bytesWritten += writer.WriteUInt16(StatusFlags);
56+
57+
// Write warnings
58+
bytesWritten += writer.WriteUInt16(Warnings);
59+
60+
// Write info string if present
61+
if (!string.IsNullOrEmpty(Info))
62+
{
63+
var infoBytes = System.Text.Encoding.UTF8.GetBytes(Info);
64+
var span = writer.GetSpan(infoBytes.Length);
65+
for (int i = 0; i < infoBytes.Length; i++)
66+
span[i] = infoBytes[i];
67+
writer.Advance(infoBytes.Length);
68+
bytesWritten += infoBytes.Length;
69+
}
70+
71+
return bytesWritten;
72+
}
73+
}
74+
}

0 commit comments

Comments
 (0)