-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpayments_nl_sql_agent.py
More file actions
429 lines (352 loc) · 15.9 KB
/
payments_nl_sql_agent.py
File metadata and controls
429 lines (352 loc) · 15.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
# payments_nl_sql_agent.py
#
# This is the brain of the project — takes a plain english (or spanish) question,
# turns it into SQL, runs it, and gives back a human summary.
#
# Pipeline:
# 1. read the DB schema and describe it as text
# 2. send schema + question to LLM → get SQL back
# 3. run the SQL against the DB
# 4. send results back to LLM → get a plain-english summary
#
# The LLM parts are stubs right now — look for "TODO: Replace" to find them.
# Everything else is real and works without an API key.
#
# Run: python payments_nl_sql_agent.py
#
# ---- LESSONS LEARNED ----
# 1. LLMs sometimes wrap SQL in ```sql ... ``` markdown — you have to strip that
# before running it or sqlite3 will just error out with a confusing message
# 2. The schema description you feed the LLM matters A LOT. First version didn't
# include FK info and the model kept writing broken JOINs. Adding FK lines fixed it.
# 3. Prompt injection is a real concern — someone could ask "DROP TABLE payments"
# disguised as a question. Added a regex check so only SELECT gets through.
# 4. Learned that sqlite3.Row lets you access columns by name, not just index —
# much nicer for debugging but you still need to convert to tuple for JSON
# 5. Spanish questions work fine in the prompt as-is, no translation needed.
# The LLM handles it natively.
# -------------------------
import re
import sqlite3
import logging
from typing import Any
# SQLite file that both schema inspection and query execution open
# (a relative path, so it resolves against the current working directory).
DB_PATH = "payments_demo.db"
# Console logging so each pipeline step is visible as it runs.
# NOTE: logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger(__name__)
# ---- helpers -----------------------------------------------------------------
def _extract_question(prompt):
# pull just the user question out of the full prompt string
# the prompt always ends with: Question: "..."
# matching on the whole prompt was a bug — schema text contains words like
# "failed" and "clients" which made the wrong stub branch fire every time
m = re.search(r'Question:\s+"(.+)"', prompt, re.IGNORECASE)
if m:
return m.group(1).lower()
# fallback: shouldn't happen but if the prompt format changes, return full text
return prompt.lower()
# ---- LLM stubs ---------------------------------------------------------------
# These two functions are where the Claude API will eventually go.
# Right now they return hardcoded responses so the whole pipeline can be tested.
# Search for "TODO: Replace" to find the exact spots to change.
def call_llm_for_sql(prompt):
    """Stub for the natural-language -> SQL LLM call.

    Chooses one of three canned SQLite SELECT statements by keyword-matching
    the user question pulled out of ``prompt`` by ``_extract_question``, so
    the whole pipeline runs without an API key. Logs a warning on every call.

    Returns:
        A SQL string with no markdown fences and no trailing semicolon.

    TODO: Replace this entire function body with a real Claude API call:

        import anthropic
        client = anthropic.Anthropic()  # picks up ANTHROPIC_API_KEY from env
        resp = client.messages.create(
            model="claude-opus-4-6",
            max_tokens=512,
            messages=[{"role": "user", "content": prompt}],
        )
        return resp.content[0].text.strip()

    The stub is kept so the demo runs without spending API credits.
    """
    log.warning("call_llm_for_sql: stub mode, no real API call")

    # Match on the extracted question only: the full prompt embeds the schema,
    # whose words ("failed", "clients") used to make the wrong branch fire.
    question = _extract_question(prompt)

    # Spanish demo question: failed Enterprise volume in January.
    if all(word in question for word in ("volumen", "fallidos", "enterprise")):
        return (
            "SELECT c.client_name, c.segment, c.country, "
            "SUM(p.amount) AS total_failed_amount, "
            "COUNT(p.payment_id) AS failed_count "
            "FROM payments p "
            "JOIN clients c ON c.client_id = p.client_id "
            "WHERE p.status = 'failed' "
            "AND c.segment = 'Enterprise' "
            "AND strftime('%Y-%m', p.date) = '2024-01' "
            "GROUP BY c.client_id "
            "ORDER BY total_failed_amount DESC"
        )

    # English demo question: which clients have the most failed payments.
    mentions_clients = "client" in question or "highest" in question
    if "failed" in question and mentions_clients:
        return (
            "SELECT c.client_name, c.segment, c.country, "
            "COUNT(p.payment_id) AS failed_count, "
            "SUM(p.amount) AS total_failed_volume "
            "FROM payments p "
            "JOIN clients c ON c.client_id = p.client_id "
            "WHERE p.status = 'failed' "
            "GROUP BY c.client_id "
            "ORDER BY failed_count DESC "
            "LIMIT 10"
        )

    # Fallback: volume by currency for the latest month in the dataset. The
    # data is fixed to 2024, so date('now', '-1 month') would return zero
    # rows; the month is hard-coded to '2024-12' instead.
    return (
        "SELECT currency, "
        "COUNT(payment_id) AS payment_count, "
        "SUM(amount) AS total_volume "
        "FROM payments "
        "WHERE strftime('%Y-%m', date) = '2024-12' "
        "GROUP BY currency "
        "ORDER BY total_volume DESC"
    )
def call_llm_for_summary(prompt):
    """Stub for the results -> business-summary LLM call.

    Ignores ``prompt`` and returns fixed placeholder text that marks where
    the real summary will appear. Logs a warning on every call.

    TODO: Replace this entire function body with a real Claude API call:

        import anthropic
        client = anthropic.Anthropic()
        resp = client.messages.create(
            model="claude-opus-4-6",
            max_tokens=256,
            messages=[{"role": "user", "content": prompt}],
        )
        return resp.content[0].text.strip()
    """
    log.warning("call_llm_for_summary: stub mode, no real API call")
    placeholder_parts = (
        "[STUB — replace call_llm_for_summary() with a real Claude API call] ",
        "Query ran fine and returned the rows above. ",
        "Wire up the API and this will be a proper business-friendly summary.",
    )
    return "".join(placeholder_parts)
# ---- schema inspection -------------------------------------------------------
def get_table_names(cur):
    """Return the names of all user-created tables, sorted by name.

    SQLite's internal tables (``sqlite_sequence`` and friends) are excluded:
    including them added noise to the schema description and confused the
    LLM.

    Args:
        cur: an open sqlite3 cursor.
    """
    # In LIKE, "_" is a single-character wildcard, so the old pattern
    # 'sqlite_%' also hid ordinary user tables such as "sqlitestats".
    # Escaping it restricts the filter to the literal "sqlite_" prefix.
    cur.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\' "
        "ORDER BY name;"
    )
    return [name for (name,) in cur.fetchall()]
def describe_table(cur, table):
    """Return schema-description lines for one table's columns.

    The first line is ``Table: <name>``. Each following line is
    `` - <column>: <declared type>`` plus a ``[PK]`` / ``[NOT NULL]`` flag list
    when applicable.

    Args:
        cur: an open sqlite3 cursor.
        table: the unquoted table name, e.g. from get_table_names().
    """
    # Quote the name as an SQL identifier (doubling any embedded double
    # quotes). Interpolating it bare broke the PRAGMA for names containing
    # spaces or punctuation, or that are SQL keywords.
    quoted = '"' + str(table).replace('"', '""') + '"'
    cur.execute(f"PRAGMA table_info({quoted});")
    lines = [f"Table: {table}"]
    # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
    for _cid, col_name, col_type, notnull, _default, pk in cur.fetchall():
        flags = []
        if pk:
            flags.append("PK")
        if notnull:
            flags.append("NOT NULL")
        flag_str = f" [{', '.join(flags)}]" if flags else ""
        lines.append(f" - {col_name}: {col_type}{flag_str}")
    return lines
def describe_foreign_keys(cur, table):
    """Return `` FK: <table>.<col> -> <parent>.<col>`` lines for ``table``.

    Including these lines in the schema description helps the LLM write
    correct JOINs.

    When a foreign key is declared as ``REFERENCES parent`` with no column
    list, PRAGMA foreign_key_list reports its target column as NULL. The
    target is then resolved to the parent's primary-key column instead of
    printing ``parent.None``. If that lookup finds nothing, ``<primary key>``
    is shown.

    Args:
        cur: an open sqlite3 cursor.
        table: the unquoted table name, e.g. from get_table_names().
    """
    def quoted(name):
        # SQL identifier quoting: wrap in double quotes, double embedded ones.
        return '"' + str(name).replace('"', '""') + '"'

    cur.execute(f"PRAGMA foreign_key_list({quoted(table)});")
    # Materialize first: the cursor is reused below for parent lookups.
    fk_rows = cur.fetchall()
    lines = []
    # PRAGMA foreign_key_list rows: (id, seq, ref_table, from_col, to_col, ...)
    for fk in fk_rows:
        seq, ref_table, from_col, to_col = fk[1], fk[2], fk[3], fk[4]
        if to_col is None:
            cur.execute(f"PRAGMA table_info({quoted(ref_table)});")
            # table_info's pk column is the 1-based position within the PK
            # (0 = not part of it); seq is the 0-based position in the FK.
            parent_pk = [
                col[1] for col in sorted(cur.fetchall(), key=lambda c: c[5]) if col[5]
            ]
            to_col = parent_pk[seq] if seq < len(parent_pk) else "<primary key>"
        lines.append(f" FK: {table}.{from_col} -> {ref_table}.{to_col}")
    return lines
def get_schema_description():
    """Describe the database at DB_PATH as plain text for the LLM prompt.

    The output has a header line, then for each user table its column lines
    and FK lines followed by a blank line, and finally a fixed block of
    domain notes (allowed status/segment/country/currency values, date
    format, covered year). Those notes stop the LLM from guessing wrong
    literal values.

    NOTE(review): sqlite3.connect() creates an empty database file when
    DB_PATH does not exist, which yields a description with no tables.
    """
    # TODO: this is slow for huge databases, might want to cache it later
    conn = sqlite3.connect(DB_PATH)
    # The connection is closed in `finally` so a failure while inspecting the
    # schema doesn't leak it (same pattern as run_sql_query).
    try:
        cur = conn.cursor()
        lines = ["SQLite database schema (B2B payments, LatAm focus):\n"]
        for table in get_table_names(cur):
            lines += describe_table(cur, table)
            lines += describe_foreign_keys(cur, table)
            lines.append("")  # blank line between tables
    finally:
        conn.close()
    lines += [
        "Domain notes:",
        " - payments.status values: 'completed', 'pending', 'failed'",
        " - clients.segment values: 'SMB', 'Mid-Market', 'Enterprise'",
        " - clients.country values: CO, MX, AR, BR, CL, PE",
        " - payments.currency: COP, MXN, ARS, BRL, CLP, PEN",
        " - payments.direction: 'incoming' (money received), 'outgoing' (money sent)",
        " - payments.method: 'bank_transfer', 'card', 'wallet'",
        " - All date columns use ISO-8601 format (YYYY-MM-DD or YYYY-MM-DD HH:MM:SS)",
        " - The database covers payment activity for the full year 2024",
    ]
    return "\n".join(lines)
# ---- SQL generation + safety guardrails --------------------------------------
# these keywords should never appear in a query the LLM gives us
# FIXME: this regex is probably not exhaustive, but covers the obvious ones
_DANGEROUS = re.compile(
r"\b(DROP|DELETE|UPDATE|INSERT|CREATE|ALTER|TRUNCATE|REPLACE|MERGE|ATTACH|DETACH)\b",
re.IGNORECASE,
)
def strip_markdown_fences(raw):
    """Return the bare SQL from an LLM reply, without markdown fences or a trailing ';'.

    - If the reply contains a closed fenced block (```sql, ```sqlite,
      ```sqlite3 or a bare ```), only that block's contents are kept, so any
      prose the model wrote around it is dropped.
    - Otherwise stray fence markers are removed and everything else is kept.
    - Surrounding whitespace and the whole trailing run of semicolons and
      whitespace are removed.
    """
    # LLMs often wrap SQL in ```sql ... ``` and sqlite3's error for a fenced
    # query is unhelpful, so strip fences before anything runs. The tag
    # alternatives list "sqlite3?" first so ```sqlite doesn't leave "ite".
    fence = r"```(?:sqlite3?|sql)?"
    block = re.search(fence + r"(.*?)```", raw, flags=re.IGNORECASE | re.DOTALL)
    if block:
        cleaned = block.group(1)
    else:
        cleaned = re.sub(fence, "", raw, flags=re.IGNORECASE)
    # rstrip(";") alone stopped at whitespace ("SELECT 1; ;" kept a ';');
    # strip the whole trailing run of ';' and whitespace instead.
    return re.sub(r"[\s;]+\Z", "", cleaned.strip())
def check_sql_is_safe(sql):
    """Guardrail: allow only plain SELECT statements through.

    Returns None when the SQL passes.

    Raises:
        ValueError: if the statement does not start with SELECT (ignoring
            surrounding whitespace and case), or if any keyword matched by
            _DANGEROUS appears anywhere in it.
    """
    looks_like_select = sql.strip().upper().startswith("SELECT")
    if not looks_like_select:
        raise ValueError(f"generated SQL doesn't start with SELECT, bailing out. got: {sql[:120]}")
    forbidden = _DANGEROUS.search(sql)
    if forbidden is not None:
        raise ValueError(f"SQL contains forbidden keyword '{forbidden.group()}', refusing to run it")
def build_sql_prompt(question, schema):
    """Build the text-to-SQL prompt sent to the LLM.

    Layout, one item per line: role line, the schema text, the task line,
    "Rules:" and the bullet rules, then ``Question: "<question>"``. The
    result ends with a newline. The rules pin dates to 2024 because the
    dataset is fixed to that year.
    """
    rules = (
        "Return ONLY the SQL query — no explanations, no markdown, no semicolons.",
        "Only use tables and columns from the schema above.",
        "For date filtering use SQLite's strftime() function.",
        "The dataset covers 2024 only. Do NOT use date('now') — it will return zero rows.",
        "\"last month\" or \"most recent month\" means '2024-12'. Use: strftime('%Y-%m', date) = '2024-12'",
        "\"enero\" / \"January\" means filter on '2024-01'. Map other months similarly (e.g. febrero -> '2024-02').",
        "Always include a meaningful ORDER BY clause.",
        "For open-ended \"top N\" queries, limit to 10 rows unless asked otherwise.",
    )
    prompt_lines = [
        "You are an expert SQL analyst for a B2B payments company in Latin America.",
        f"{schema}",
        "Convert the question below into a single SQL SELECT query for SQLite 3.",
        "Rules:",
    ]
    prompt_lines.extend(f"- {rule}" for rule in rules)
    prompt_lines.append(f'Question: "{question}"')
    return "\n".join(prompt_lines) + "\n"
def generate_sql_from_question(question, schema):
    """Turn a natural-language question into guard-checked SQL.

    Steps: build the prompt, ask the LLM, strip markdown fences, then run
    the SELECT-only guardrail.

    Raises:
        ValueError: from check_sql_is_safe() when the SQL is not a plain
            SELECT or contains a forbidden keyword.
    """
    log.info("generating SQL for question...")
    llm_output = call_llm_for_sql(build_sql_prompt(question, schema))
    cleaned = strip_markdown_fences(llm_output)
    check_sql_is_safe(cleaned)  # raises ValueError, so unsafe SQL never returns
    log.info(f"SQL looks good: {cleaned[:80]}...")
    return cleaned
# ---- query execution ---------------------------------------------------------
def run_sql_query(sql):
    """Execute an already-validated SELECT against DB_PATH; return rows as tuples.

    Rows are converted from sqlite3.Row to plain tuples so the result can be
    serialized to JSON. The connection is always closed, even on error.

    Raises:
        RuntimeError: wrapping any sqlite3 failure, with the offending SQL in
            the message.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row  # lets you access columns by name during debug
    try:
        # Defense in depth behind check_sql_is_safe(): the SQL comes from an
        # LLM prompted with user text, so make SQLite itself refuse any
        # statement that would modify the database on this connection.
        conn.execute("PRAGMA query_only = ON;")
        cur = conn.cursor()
        cur.execute(sql)
        rows = [tuple(row) for row in cur.fetchall()]
        log.info(f"query returned {len(rows)} row(s)")
        return rows
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not a subclass of sqlite3.Error; older Python
        # versions raise it for multi-statement input, so catch it as well.
        raise RuntimeError(f"sqlite blew up: {e} -- SQL was: {sql}") from e
    finally:
        conn.close()  # always runs, whether the query succeeded or not
# ---- summarization -----------------------------------------------------------
def build_summary_prompt(question, sql, rows, max_rows=50):
    """Build the prompt asking the LLM to summarize query results for a manager.

    Args:
        question: the user's original question.
        sql: the SQL that was executed.
        rows: the result rows (a sliceable sequence, e.g. a list of tuples).
        max_rows: at most this many rows are embedded, to bound prompt size
            and stay within the token limit (default 50).

    When there are no rows the prompt says so. When rows were cut off it says
    how many of how many are shown. Without that note the model treated the
    preview as the complete result and reported counts or totals computed
    from partial data.
    """
    preview = rows[:max_rows]
    if not rows:
        data_note = "\n(no rows returned)"
    elif len(rows) > max_rows:
        data_note = (
            f"\n(truncated: showing the first {max_rows} of {len(rows)} rows; "
            "totals over the full result cannot be computed from this preview)"
        )
    else:
        data_note = ""
    return f"""\
You are a business analyst at a B2B payments company in Latin America.
The user asked:
"{question}"
The SQL query that ran was:
{sql}
Raw results (max {max_rows} rows shown):
{preview}{data_note}
Write a short, plain-english summary in 2-3 sentences.
- Talk like you're explaining to a non-technical manager, not a developer.
- Use actual numbers from the data.
- Do NOT make up numbers that aren't in the rows above.
- If no rows came back, say so and suggest why (filter too narrow, no data for that period, etc.).
- If amounts are in local currencies like COP or MXN, mention it.
"""
def summarize_results(question, sql, rows, schema):
    """Ask the LLM for a short business-friendly summary of ``rows``.

    ``schema`` is accepted but currently unused. It is kept in the signature
    so callers don't change if schema context is added to the prompt later.
    """
    log.info("sending results to LLM for summary...")
    return call_llm_for_summary(build_summary_prompt(question, sql, rows))
# ---- main orchestrator -------------------------------------------------------
def answer_question(question):
    """Run the whole pipeline for one question: schema -> SQL -> run -> summarize.

    Never raises. Failures are reported in the returned dict, so an API layer
    can serialize the result straight to JSON.

    Returns:
        dict with keys ``question``, ``sql``, ``rows``, ``summary``, ``error``.
        ``sql`` and ``rows`` are filled in as soon as each step succeeds, so a
        later failure (e.g. the DB rejecting the SQL, or the summary call
        failing) still reports the SQL and any rows already fetched.
        ``error`` is None on success, otherwise the exception text.
    """
    log.info(f"processing: {question!r}")
    result = {
        "question": question,
        "sql": "",
        "rows": [],
        "summary": "",
        "error": None,
    }
    try:
        schema = get_schema_description()
        sql = generate_sql_from_question(question, schema)
        result["sql"] = sql
        rows = run_sql_query(sql)
        result["rows"] = rows
        result["summary"] = summarize_results(question, sql, rows, schema)
    except ValueError as e:
        # this is the guardrail kicking in — unsafe SQL from the LLM
        log.error(f"validation problem: {e}")
        result["error"] = str(e)
        result["summary"] = f"couldn't run that query safely: {e}"
    except RuntimeError as e:
        # run_sql_query wraps sqlite errors in RuntimeError
        log.error(f"db error: {e}")
        result["error"] = str(e)
        result["summary"] = f"database error: {e}"
    except Exception as e:
        # top-level boundary: log with traceback but never crash the caller
        log.exception("unexpected error, something went wrong")
        result["error"] = str(e)
        result["summary"] = f"something went wrong: {e}"
    return result
# ---- CLI demo ----------------------------------------------------------------
def print_answer(answer):
    """Pretty-print one answer_question() result dict to stdout.

    Shows the question, the SQL, up to the first 5 rows (with a count of the
    rest), and the summary. An error section is printed only when
    ``answer["error"]`` is truthy.
    """
    divider = "-" * 60
    print("\n" + "=" * 60)
    print(f" Q: {answer['question']}")
    print(divider)
    print(f" SQL:\n {answer['sql']}")
    print(divider)

    rows = answer["rows"]
    if not rows:
        print(" rows: (none)")
    else:
        preview_size = 5
        print(f" rows ({len(rows)} total, first {preview_size}):")
        for row in rows[:preview_size]:
            print(f" {row}")
        remaining = len(rows) - preview_size
        if remaining > 0:
            print(f" ... ({remaining} more)")

    print(divider)
    print(f" summary:\n {answer['summary']}")
    if answer.get("error"):
        print(divider)
        print(f" error: {answer['error']}")
if __name__ == "__main__":
    # One demo question per stub branch in call_llm_for_sql: the fallback
    # (volume by currency), failures by client, and the Spanish Enterprise one.
    demo_questions = [
        "What was the total payment volume last month?",
        "Which clients have the highest number of failed payments?",
        "¿Cuál fue el volumen total de pagos fallidos en enero para clientes enterprise?",
    ]
    heavy_rule = "=" * 60
    banner_lines = (
        " B2B Payments NL-to-SQL Agent (stub mode)",
        " swap out call_llm_for_sql and call_llm_for_summary",
        " with real Claude API calls to go live",
    )
    print("\n" + heavy_rule)
    for banner_line in banner_lines:
        print(banner_line)
    print(heavy_rule)
    for question in demo_questions:
        print_answer(answer_question(question))