# sql_instruments.py
  1. import os, json, pymssql, asyncio, autogen, torch
  2. from datetime import datetime
  3. import util
  4. from util import get_db_param, json_to_excel, json_to_dataframe
  5. from sklearn.metrics.pairwise import cosine_similarity
  6. from transformers import BertTokenizer, BertModel
  7. from config import STATIC_DIR, bge_model_path
  8. from transformers import AutoModelForSequenceClassification, AutoTokenizer
  9. from agents import data_engineer, sql_answer
  10. from prompt import POSTGRES_TABLE_DEFINITIONS_CAP_REF, NOTE, EXAMPLE
  11. from util import plot_data
# Module-level plotting helper shared by sql_analyze_father.make_data.
plot_instance = plot_data()
  13. # MultiAgent工具的基类,提供基本的状态和功能管理
  14. class AgentInstruments:
  15. """
  16. Base class for multli-agent instruments that are tools, state, and functions that an agent can use across the lifecycle of conversations
  17. """
  18. # 初始化会话ID和消息列表
  19. def __init__(self) -> None:
  20. self.session_id = None
  21. self.messages = []
  22. # 支持上下文管理器的进入
  23. def __enter__(self):
  24. return self
  25. # 支持上下文管理器的退出
  26. def __exit__(self, exc_type, exc_value, traceback):
  27. pass
  28. # 同步消息与协调者
  29. def sync_messages(self, messages: list):
  30. """
  31. Syncs messages with the orchestrator
  32. """
  33. raise NotImplementedError
  34. # def make_agent_chat_file(self, team_name: str):
  35. # return os.path.join(self.root_dir, f"agent_chats_{team_name}.json")
  36. # def make_agent_cost_file(self, team_name: str):
  37. # return os.path.join(self.root_dir, f"agent_cost_{team_name}.json")
  38. # 返回根目录路径
  39. @property
  40. def root_dir(self):
  41. return os.path.join(STATIC_DIR, self.session_id)
  42. # 为PostgreSQL数据分析MultiAgent系统提供统一工具集
  43. class PostgresAgentInstruments(AgentInstruments):
  44. """
  45. Unified Toolset for the Postgres Data Analytics Multi-Agent System
  46. Advantages:
  47. - All agents have access to the same state and functions
  48. - Gives agent functions awareness of changing context
  49. - Clear and concise capabilities for agents
  50. - Clean database connection management
  51. Guidelines:
  52. - Agent Functions should not call other agent functions directly
  53. - Instead Agent Functions should call external lower level modules
  54. - Prefer 1 to 1 mapping of agents and their functions
  55. - The state lifecycle lives between all agent orchestrations
  56. """
  57. # 初始化数据库URL、会话ID和消息列表
  58. def __init__(self, db_url: str, session_id: str) -> None:
  59. super().__init__()
  60. self.db_url = db_url
  61. self.db = None
  62. self.session_id = session_id
  63. self.messages = []
  64. self.innovation_index = 0
  65. # 连接数据库并重置文件
  66. def __enter__(self):
  67. """
  68. Support entering the 'with' statement
  69. """
  70. self.reset_files()
  71. self.db = PostgresManager()
  72. self.db.connect_with_url(self.db_url)
  73. return self, self.db
  74. # 关闭数据库连接
  75. def __exit__(self, exc_type, exc_val, exc_tb):
  76. """
  77. Support exiting the 'with' statement
  78. """
  79. self.db.close()
  80. # 同步消息
  81. def sync_messages(self, messages: list):
  82. """
  83. Syncs messages with the orchestrator
  84. """
  85. self.messages = messages
  86. # 清空根目录下所有文件
  87. def reset_files(self):
  88. """
  89. Clear everything in the root_dir
  90. """
  91. # if it does not exist create it
  92. if not os.path.exists(self.root_dir):
  93. os.makedirs(self.root_dir)
  94. # for fname in os.listdir(self.root_dir):
  95. # os.remove(os.path.join(self.root_dir, fname))
  96. # 获取文件的完整路径
  97. def get_file_path(self, fname: str):
  98. """
  99. Get the full path to a file in the root_dir
  100. """
  101. return os.path.join(self.root_dir, fname)
  102. # -------------------------- Agent Properties -------------------------- #
  103. # 获取sql_results文件完整路径
  104. @property
  105. def run_sql_results_file(self):
  106. return self.get_file_path("run_sql_results.json")
  107. # 获取sql_query文件完整路径
  108. @property
  109. def sql_query_file(self):
  110. return self.get_file_path("sql_query.sql")
  111. # -------------------------- Agent Functions -------------------------- #
  112. # 执行SQL查询供将结果写入JSON文件
  113. def run_sql(self, sql: str) -> str:
  114. """
  115. Run a SQL query against the postgres database
  116. """
  117. results_as_json = self.db.run_sql(sql)
  118. fname = self.run_sql_results_file
  119. # dump these results to a file
  120. with open(fname, "w") as f:
  121. f.write(results_as_json)
  122. with open(self.sql_query_file, "w") as f:
  123. f.write(sql)
  124. return "Successfully delivered results to json file"
  125. # 验证SQL结果文件是否存在并有内容
  126. def validate_run_sql(self):
  127. """
  128. validate that the run_sql results file exists and has content
  129. """
  130. fname = self.run_sql_results_file
  131. with open(fname, "r") as f:
  132. content = f.read()
  133. if not content:
  134. return False, f"File {fname} is empty"
  135. return True, ""
  136. # 将内容写入文件
  137. def write_file(self, content: str):
  138. fname = self.get_file_path(f"write_file.txt")
  139. return util.write_file(fname, content)
  140. # 将JSON字符串写入文件
  141. def write_json_file(self, json_str: str):
  142. fname = self.get_file_path(f"write_json_file.json")
  143. return util.write_json_file(fname, json_str)
  144. # 将JSON字符串转换为YAML格式写入文件
  145. def write_yml_file(self, json_str: str):
  146. fname = self.get_file_path(f"write_yml_file.yml")
  147. return util.write_yml_file(fname, json_str)
  148. # 写入创建文件并更新索引
  149. def write_innovation_file(self, content: str):
  150. fname = self.get_file_path(f"{self.innovation_index}_innovation_file.json")
  151. util.write_file(fname, content)
  152. self.innovation_index += 1
  153. return f"Successfully wrote innovation file. You can check my work."
  154. # 验证所有创建文件是否存在并有内容
  155. def validate_innovation_files(self):
  156. """
  157. loop from 0 to innovation_index and verify file exists with content
  158. """
  159. for i in range(self.innovation_index):
  160. fname = self.get_file_path(f"{i}_innovation_file.json")
  161. with open(fname, "r") as f:
  162. content = f.read()
  163. if not content:
  164. return False, f"File {fname} is empty"
  165. return True, ""
  166. # 管理PostgreSQL数据库的连接和查询
  167. class PostgresManager:
  168. """
  169. A class to manage postgres connections and queries
  170. """
  171. def __init__(self):
  172. self.conn = None
  173. self.cur = None
  174. def __enter__(self):
  175. return self
  176. def __exit__(self, exc_type, exc_val, exc_tb):
  177. if self.cur:
  178. self.cur.close()
  179. if self.conn:
  180. self.conn.close()
  181. def connect_with_url(self, url):
  182. self.conn = pymssql.connect(**url)
  183. self.cur = self.conn.cursor()
  184. def close(self):
  185. if self.cur:
  186. self.cur.close()
  187. if self.conn:
  188. self.conn.close()
  189. def run_sql(self, sql) -> str:
  190. """
  191. Run a SQL query against the postgres database
  192. """
  193. try:
  194. self.cur.execute(sql)
  195. columns = [desc[0] for desc in self.cur.description]
  196. res = self.cur.fetchall()
  197. list_of_dicts = [dict(zip(columns, row)) for row in res]
  198. json_result = json.dumps(list_of_dicts, indent=4, ensure_ascii=False, default=self.datetime_handler)
  199. return json_result
  200. except Exception as e:
  201. return f'Error occurred when execute the sql: {str(e)} Please construct a new SQL query.'
  202. def datetime_handler(self, obj):
  203. """
  204. Handle datetime objects when serializing to JSON.
  205. """
  206. if isinstance(obj, datetime):
  207. return obj.isoformat()
  208. return str(obj) # or just return the object unchanged, or another default value
  209. def get_table_definition(self, table_name):
  210. """
  211. Generate the 'create' definition for a table
  212. """
  213. get_def_stmt = """
  214. SELECT
  215. t.name AS tablename,
  216. c.column_id AS attnum,
  217. c.name AS attname,
  218. TYPE_NAME(c.system_type_id) AS data_type
  219. FROM
  220. sys.tables t
  221. JOIN
  222. sys.columns c ON t.object_id = c.object_id
  223. WHERE
  224. t.name = %s -- Assuming @TableName is a parameter
  225. AND SCHEMA_NAME(t.schema_id) = 'dbo' -- Assuming you're interested in dbo schema
  226. ORDER BY
  227. c.column_id;
  228. """
  229. self.cur.execute(get_def_stmt, (table_name,))
  230. rows = self.cur.fetchall()
  231. create_table_stmt = "CREATE TABLE {} (\n".format(table_name)
  232. for row in rows:
  233. create_table_stmt += "{} {},\n".format(row[2], row[3])
  234. create_table_stmt = create_table_stmt.rstrip(",\n") + "\n);"
  235. return create_table_stmt
  236. def get_all_table_names(self):
  237. """
  238. Get all table names in the database
  239. """
  240. get_all_tables_stmt = (
  241. "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';"
  242. )
  243. self.cur.execute(get_all_tables_stmt)
  244. return [row[0] for row in self.cur.fetchall()]
  245. def get_table_definitions_for_prompt(self):
  246. """
  247. Get all table 'create' definitions in the database
  248. """
  249. table_names = self.get_all_table_names()
  250. definitions = []
  251. for table_name in table_names:
  252. definitions.append(self.get_table_definition(table_name))
  253. return "\n\n".join(definitions)
  254. def get_table_definition_map_for_embeddings(self):
  255. """
  256. Creates a map of table names to table definitions
  257. """
  258. table_names = self.get_all_table_names()
  259. definitions = {}
  260. for table_name in table_names:
  261. definitions[table_name] = self.get_table_definition(table_name)
  262. return definitions
  263. def get_related_tables(self, table_list, n=2):
  264. """
  265. Get tables that have foreign keys referencing the given table
  266. """
  267. related_tables_dict = {}
  268. for table in table_list:
  269. # Query to fetch tables that have foreign keys referencing the given table
  270. self.cur.execute(
  271. """
  272. SELECT
  273. OBJECT_NAME(fk.parent_object_id) AS table_name
  274. FROM
  275. sys.foreign_keys fk
  276. WHERE
  277. fk.referenced_object_id = OBJECT_ID(%s)
  278. ORDER BY
  279. table_name
  280. OFFSET 0 ROWS FETCH NEXT %s ROWS ONLY;
  281. """,
  282. (table, n),
  283. )
  284. related_tables = [row[0] for row in self.cur.fetchall()]
  285. # Query to fetch tables that the given table references
  286. self.cur.execute(
  287. """
  288. SELECT
  289. OBJECT_NAME(fk.parent_object_id) AS table_name
  290. FROM
  291. sys.foreign_keys fk
  292. WHERE
  293. fk.referenced_object_id = OBJECT_ID(%s)
  294. ORDER BY
  295. table_name
  296. OFFSET 0 ROWS FETCH NEXT %s ROWS ONLY;
  297. """,
  298. (table, n),
  299. )
  300. related_tables += [row[0] for row in self.cur.fetchall()]
  301. related_tables_dict[table] = related_tables
  302. # convert dict to list and remove dups
  303. related_tables_list = []
  304. for table, related_tables in related_tables_dict.items():
  305. related_tables_list += related_tables
  306. related_tables_list = list(set(related_tables_list))
  307. return related_tables_list
  308. # 负责嵌入数据库表定义并计算用户查询与表定义之间的相似性
  309. class DatabaseEmbedder:
  310. """
  311. This class is responsible for embedding database table definitions and
  312. computing similarity between user queries and table definitions.
  313. """
  314. def __init__(self, db: PostgresManager, rerank: bool):
  315. # self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", local_files_only=True)
  316. # self.model = BertModel.from_pretrained("bert-base-uncased", local_files_only=True)
  317. if rerank:
  318. self.model = AutoModelForSequenceClassification.from_pretrained(bge_model_path)
  319. self.tokenizer = AutoTokenizer.from_pretrained(bge_model_path)
  320. else:
  321. self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", local_files_only=True)
  322. self.model = BertModel.from_pretrained("bert-base-uncased", local_files_only=True)
  323. self.map_name_to_embeddings = {}
  324. self.map_name_to_table_def = {}
  325. self.db = db
  326. self.rerank = rerank
  327. def get_similar_table_defs_for_prompt(self, prompt: str, n_similar=5, n_foreign=0):
  328. map_table_name_to_table_def = self.db.get_table_definition_map_for_embeddings()
  329. for name, table_def in map_table_name_to_table_def.items():
  330. self.add_table(name, table_def)
  331. similar_tables = self.get_similar_tables(prompt, n=n_similar)
  332. table_definitions = self.get_table_definitions_from_names(similar_tables)
  333. if n_foreign > 0:
  334. foreign_table_names = self.db.get_foreign_tables(similar_tables, n=3)
  335. table_definitions = self.get_table_definitions_from_names(
  336. foreign_table_names + similar_tables
  337. )
  338. return table_definitions
  339. def add_table(self, table_name: str, text_representation: str):
  340. """
  341. Add a table to the database embedder.
  342. Map the table name to its embedding and text representation.
  343. """
  344. if self.rerank:
  345. self.map_name_to_embeddings[table_name] = None
  346. else:
  347. self.map_name_to_embeddings[table_name] = self.compute_embeddings(
  348. text_representation
  349. )
  350. self.map_name_to_table_def[table_name] = text_representation
  351. def compute_embeddings(self, text):
  352. """
  353. Compute embeddings for a given text using the BERT model.
  354. """
  355. inputs = self.tokenizer(
  356. text, return_tensors="pt", truncation=True, padding=True, max_length=512
  357. )
  358. outputs = self.model(**inputs)
  359. return outputs["pooler_output"].detach().numpy()
  360. def get_similar_tables_via_rerank(self,query,n=5):
  361. self.model.eval()
  362. with torch.no_grad():
  363. result = {}
  364. for tab, tab_def in self.map_name_to_table_def.items():
  365. inputs_1 = self.tokenizer([[query, tab]], padding=True, truncation=True, return_tensors='pt', max_length=512)
  366. scores_1 = self.model(**inputs_1, return_dict=True).logits.view(-1, ).float()[0]
  367. inputs_2 = self.tokenizer([[query, tab_def]], padding=True, truncation=True, return_tensors='pt', max_length=512)
  368. scores_2 = self.model(**inputs_2, return_dict=True).logits.view(-1, ).float()[0]
  369. score = 0.7*scores_1 + 0.3*scores_2
  370. probs = torch.sigmoid(score)
  371. result[tab] = probs
  372. print(f'similarity : {result}')
  373. sorted_results = sorted(result.items(), key=lambda x: x[1], reverse=True)
  374. final_result = [x[0] for x in sorted_results]
  375. return final_result[:n]
  376. def get_similar_tables_via_embeddings(self, query, n=3):
  377. """
  378. Given a query, find the top 'n' tables that are most similar to it.
  379. Args:
  380. - query (str): The user's natural language query.
  381. - n (int, optional): Number of top tables to return. Defaults to 3.
  382. Returns:
  383. - list: Top 'n' table names ranked by their similarity to the query.
  384. """
  385. # Compute the embedding for the user's query
  386. query_embedding = self.compute_embeddings(query)
  387. # Calculate cosine similarity between the query and all tables
  388. similarities = {
  389. table: cosine_similarity(query_embedding, emb)[0][0]
  390. for table, emb in self.map_name_to_embeddings.items()
  391. }
  392. # Rank tables based on their similarity scores and return top 'n'
  393. return sorted(similarities, key=similarities.get, reverse=True)[:n]
  394. def get_similar_table_names_via_word_match(self, query: str):
  395. """
  396. if any word in our query is a table name, add the table to a list
  397. """
  398. tables = []
  399. for table_name in self.map_name_to_table_def.keys():
  400. if table_name.lower() in query.lower():
  401. tables.append(table_name)
  402. return tables
  403. def get_similar_tables(self, query: str, n=3):
  404. """
  405. combines results from get_similar_tables_via_embeddings and get_similar_table_names_via_word_match
  406. """
  407. if self.rerank:
  408. similar_tables_via_embeddings = self.get_similar_tables_via_rerank(query, n)
  409. else:
  410. similar_tables_via_embeddings = self.get_similar_tables_via_embeddings(query, n)
  411. similar_tables_via_word_match = self.get_similar_table_names_via_word_match(
  412. query
  413. )
  414. temp_list = similar_tables_via_embeddings + similar_tables_via_word_match
  415. unique_list = list(dict.fromkeys(temp_list))
  416. return unique_list
  417. def get_table_definitions_from_names(self, table_names: list) -> str:
  418. """
  419. Given a list of table names, return their table definitions.
  420. """
  421. table_defs = [
  422. self.map_name_to_table_def[table_name] for table_name in table_names
  423. ]
  424. return "\n\n".join(table_defs)
  425. # 处理SQL分析的主要逻辑
  426. class sql_analyze_father:
  427. def __init__(self, data_engineer:autogen.AssistantAgent, client_id: str, db_param: dict, table_name=[]) -> None:
  428. self.sql_generator = data_engineer
  429. self.db_param = db_param
  430. self.client_id = client_id
  431. self.table_name = table_name
  432. def get_sql(self, content):
  433. sql = content['content']
  434. if sql.startswith("SQL query:\n"):
  435. return sql.split(':')[1].strip()
  436. elif '```' in sql:
  437. return sql.split('```')[1].strip('sql')
  438. else:
  439. return sql
  440. def add_cap_ref(self,
  441. prompt: str, prompt_suffix: str, cap_ref: str, cap_ref_content: str, note: str, example: str
  442. ) -> str:
  443. new_prompt = f"""{prompt} {prompt_suffix}\n\n{cap_ref}\n\n{cap_ref_content}\n\n{note}\n\n{example}"""
  444. return new_prompt
  445. async def run_sql_analyze(self, raw_prompt):
  446. with PostgresAgentInstruments(self.db_param, self.client_id) as (agent_instruments, db):
  447. map_table_name_to_table_def = db.get_table_definition_map_for_embeddings()
  448. database_embedder = DatabaseEmbedder(db, rerank=True)
  449. for name, table_def in map_table_name_to_table_def.items():
  450. database_embedder.add_table(name, table_def)
  451. if not self.table_name or self.table_name==[]:
  452. similar_tables = database_embedder.get_similar_tables(raw_prompt, n=5)
  453. print(f'similar_tables {similar_tables}')
  454. table_definitions = database_embedder.get_table_definitions_from_names(
  455. similar_tables
  456. )
  457. else:
  458. table_definitions = database_embedder.get_table_definitions_from_names(
  459. self.table_name
  460. )
  461. prompt = f"Please meet the needs of the user: {raw_prompt}, "
  462. prompt = self.add_cap_ref(
  463. prompt,
  464. f"and use these {POSTGRES_TABLE_DEFINITIONS_CAP_REF} to satisfy the database query.Please ensure that SQL has the highest efficiency and conforms to the syntax of the database.",
  465. POSTGRES_TABLE_DEFINITIONS_CAP_REF,
  466. table_definitions, NOTE, EXAMPLE
  467. )
  468. messages = [{'role': 'user', 'content': prompt}]
  469. results = '[]'
  470. i = 0
  471. try:
  472. while i < 3 and (len(results)==0 or results == '[]' or 'Error occurred' in results ):
  473. sql_reply = await data_engineer.a_generate_reply(messages=messages)
  474. sql_reply = sql_reply if isinstance(sql_reply, dict) else {'role':'assistant', 'content':sql_reply}
  475. sql = self.get_sql(sql_reply)
  476. if 'I dont know' in sql:
  477. i +=1
  478. continue
  479. messages.append({'role':'user','content': sql})
  480. results = db.run_sql(sql)
  481. messages.append({'role':'assistant','content': results})
  482. i += 1
  483. print(f'messages before *****{messages}')
  484. if i == 3 and (len(results)==0 or results == '[]' or 'Error occurred' in results):
  485. del messages[-6:]
  486. if 'I dont know' in sql:
  487. messages.append({'role':'assistant','content':f'根据所提供的问题和表的信息的关联不够, 我无法召回相关的数据'})
  488. else:
  489. messages.append({'role':'assistant','content':f'生成sql出现了问题,结果为: {results}'})
  490. else:
  491. del messages[-2*i:-2]
  492. print('\n ---------------- \n')
  493. print(f'messages after *****{messages}')
  494. except Exception as e:
  495. print(e)
  496. data_sql = messages[-1].get('content')
  497. summary_messages = [{'role':'user','content':raw_prompt}, {'role':'assistant','content':f'生成的sql: \n {sql} \n 执行的数据结果: {data_sql}'}]
  498. print(summary_messages)
  499. summary = await sql_answer.a_generate_reply(messages=summary_messages)
  500. summary_content = summary['content'] if isinstance(summary, dict) else summary
  501. print(f'final_answer: \n {summary_content}\n')
  502. return sql, results, summary_content
  503. def make_data(self, input_data, excel_file, plot_file):
  504. json_to_excel(input_data, excel_file)
  505. df = json_to_dataframe(input_data)
  506. print(df)
  507. plot_data_html = plot_instance.auto_plot(df, plot_file)
  508. return df, plot_file
if __name__ == '__main__':
    # Smoke test: run a full analyze cycle against the sales database and
    # export the results to Excel plus an auto-generated HTML plot.
    db_param = get_db_param('sale_database')
    sql_instance = sql_analyze_father(data_engineer=data_engineer, client_id='dalin', db_param=db_param)
    # Prompt asks: "which clothing styles are most popular in sales".
    sql, results, summary = asyncio.run(sql_instance.run_sql_analyze(raw_prompt='哪些服装款式在销售中最受欢迎'))
    df, plot_file = sql_instance.make_data(results, './xxx.xlsx', './yyyy.html')
    print(sql,results, summary)
    pass