import os, json, pymssql, asyncio, autogen, torch
from datetime import datetime
import util
from util import get_db_param, json_to_excel, json_to_dataframe, plot_data
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertModel,
    BertTokenizer,
)
from config import STATIC_DIR, bge_model_path
from agents import data_engineer, sql_answer
from prompt import POSTGRES_TABLE_DEFINITIONS_CAP_REF, NOTE, EXAMPLE

# Module-level plotting helper
plot_instance = plot_data()
# Base class for multi-agent tools: shared state and lifecycle management.
class AgentInstruments:
    """
    Base class for multi-agent instruments: the tools, state, and functions
    that an agent can use across the lifecycle of a conversation.
    """

    def __init__(self) -> None:
        # Session id and message history shared across agents
        self.session_id = None
        self.messages = []

    # Support entering a 'with' block
    def __enter__(self):
        return self

    # Support exiting a 'with' block
    def __exit__(self, exc_type, exc_value, traceback):
        pass

    def sync_messages(self, messages: list):
        """
        Syncs messages with the orchestrator
        """
        raise NotImplementedError

    # def make_agent_chat_file(self, team_name: str):
    #     return os.path.join(self.root_dir, f"agent_chats_{team_name}.json")

    # def make_agent_cost_file(self, team_name: str):
    #     return os.path.join(self.root_dir, f"agent_cost_{team_name}.json")

    # Root directory for this session's artifacts
    @property
    def root_dir(self):
        return os.path.join(STATIC_DIR, self.session_id)
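
# --- Illustrative sketch (not part of the original module) ----------------- #
# A minimal, hypothetical subclass showing the intended contract: a concrete
# instruments class carries per-session state and implements sync_messages.
class InMemoryInstruments(AgentInstruments):
    def __init__(self, session_id: str) -> None:
        super().__init__()
        self.session_id = session_id

    def sync_messages(self, messages: list):
        # Keep a reference to the orchestrator's message list
        self.messages = messages
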
# Unified toolset for the PostgreSQL data-analytics multi-agent system.
class PostgresAgentInstruments(AgentInstruments):
    """
    Unified Toolset for the Postgres Data Analytics Multi-Agent System

    Advantages:
    - All agents have access to the same state and functions
    - Gives agent functions awareness of changing context
    - Clear and concise capabilities for agents
    - Clean database connection management

    Guidelines:
    - Agent functions should not call other agent functions directly
    - Instead, agent functions should call external lower-level modules
    - Prefer a 1-to-1 mapping of agents and their functions
    - The state lifecycle lives across all agent orchestrations
    """

    def __init__(self, db_url: str, session_id: str) -> None:
        # Store the database connection parameters, session id, and message list
        super().__init__()
        self.db_url = db_url
        self.db = None
        self.session_id = session_id
        self.messages = []
        self.innovation_index = 0

    def __enter__(self):
        """
        Support entering the 'with' statement: reset session files and
        open the database connection.
        """
        self.reset_files()
        self.db = PostgresManager()
        self.db.connect_with_url(self.db_url)
        return self, self.db

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Support exiting the 'with' statement: close the database connection.
        """
        self.db.close()

    def sync_messages(self, messages: list):
        """
        Syncs messages with the orchestrator
        """
        self.messages = messages

    def reset_files(self):
        """
        Ensure the session's root_dir exists (and optionally clear it).
        """
        # Create the directory if it does not exist
        if not os.path.exists(self.root_dir):
            os.makedirs(self.root_dir)
        # for fname in os.listdir(self.root_dir):
        #     os.remove(os.path.join(self.root_dir, fname))

    def get_file_path(self, fname: str):
        """
        Get the full path to a file in the root_dir
        """
        return os.path.join(self.root_dir, fname)
    # -------------------------- Agent Properties -------------------------- #

    @property
    def run_sql_results_file(self):
        # Full path of the JSON file holding the latest SQL results
        return self.get_file_path("run_sql_results.json")

    @property
    def sql_query_file(self):
        # Full path of the file holding the latest SQL query
        return self.get_file_path("sql_query.sql")

    # -------------------------- Agent Functions -------------------------- #

    def run_sql(self, sql: str) -> str:
        """
        Run a SQL query against the database and write the results to a JSON file.
        """
        results_as_json = self.db.run_sql(sql)
        fname = self.run_sql_results_file
        # Dump the results to a file
        with open(fname, "w") as f:
            f.write(results_as_json)
        with open(self.sql_query_file, "w") as f:
            f.write(sql)
        return "Successfully delivered results to json file"

    def validate_run_sql(self):
        """
        Validate that the run_sql results file exists and has content.
        """
        fname = self.run_sql_results_file
        with open(fname, "r") as f:
            content = f.read()
        if not content:
            return False, f"File {fname} is empty"
        return True, ""
    def write_file(self, content: str):
        # Write plain text content to a file in root_dir
        fname = self.get_file_path("write_file.txt")
        return util.write_file(fname, content)

    def write_json_file(self, json_str: str):
        # Write a JSON string to a file in root_dir
        fname = self.get_file_path("write_json_file.json")
        return util.write_json_file(fname, json_str)

    def write_yml_file(self, json_str: str):
        # Convert a JSON string to YAML and write it to a file in root_dir
        fname = self.get_file_path("write_yml_file.yml")
        return util.write_yml_file(fname, json_str)

    def write_innovation_file(self, content: str):
        # Write a new innovation file and advance the index
        fname = self.get_file_path(f"{self.innovation_index}_innovation_file.json")
        util.write_file(fname, content)
        self.innovation_index += 1
        return "Successfully wrote innovation file. You can check my work."

    def validate_innovation_files(self):
        """
        Loop from 0 to innovation_index and verify that each file exists with content.
        """
        for i in range(self.innovation_index):
            fname = self.get_file_path(f"{i}_innovation_file.json")
            with open(fname, "r") as f:
                content = f.read()
                if not content:
                    return False, f"File {fname} is empty"
        return True, ""
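
# --- Illustrative sketch (not part of the original module) ----------------- #
# Typical lifecycle, assuming 'db_params' is a pymssql-compatible parameter
# dict such as the one returned by util.get_db_param. Commented out because it
# needs a live database; the 'sales' table is hypothetical.
#
# with PostgresAgentInstruments(db_params, session_id="demo") as (instruments, db):
#     instruments.run_sql("SELECT TOP 3 * FROM sales;")
#     ok, err = instruments.validate_run_sql()
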
# Manages database connections and queries. Despite the name, this class
# connects through pymssql, so the queries below use SQL Server system views.
class PostgresManager:
    """
    A class to manage database connections and queries
    """

    def __init__(self):
        self.conn = None
        self.cur = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.cur:
            self.cur.close()
        if self.conn:
            self.conn.close()

    def connect_with_url(self, url):
        # 'url' is a dict of pymssql connection parameters
        self.conn = pymssql.connect(**url)
        self.cur = self.conn.cursor()

    def close(self):
        if self.cur:
            self.cur.close()
        if self.conn:
            self.conn.close()

    def run_sql(self, sql) -> str:
        """
        Run a SQL query against the database and return the rows as a JSON string.
        """
        try:
            self.cur.execute(sql)
            columns = [desc[0] for desc in self.cur.description]
            res = self.cur.fetchall()
            list_of_dicts = [dict(zip(columns, row)) for row in res]
            json_result = json.dumps(
                list_of_dicts, indent=4, ensure_ascii=False, default=self.datetime_handler
            )
            return json_result
        except Exception as e:
            # The 'Error occurred' prefix is matched by the retry loop below, so keep it stable.
            return f"Error occurred when executing the SQL: {str(e)} Please construct a new SQL query."

    def datetime_handler(self, obj):
        """
        Handle datetime objects when serializing to JSON.
        """
        if isinstance(obj, datetime):
            return obj.isoformat()
        return str(obj)  # fall back to the string representation
    def get_table_definition(self, table_name):
        """
        Generate a 'CREATE TABLE' style definition for a table
        """
        get_def_stmt = """
            SELECT
                t.name AS tablename,
                c.column_id AS attnum,
                c.name AS attname,
                TYPE_NAME(c.system_type_id) AS data_type
            FROM sys.tables t
            JOIN sys.columns c ON t.object_id = c.object_id
            WHERE t.name = %s                       -- table name parameter
              AND SCHEMA_NAME(t.schema_id) = 'dbo'  -- only the dbo schema
            ORDER BY c.column_id;
        """
        self.cur.execute(get_def_stmt, (table_name,))
        rows = self.cur.fetchall()
        create_table_stmt = "CREATE TABLE {} (\n".format(table_name)
        for row in rows:
            create_table_stmt += "{} {},\n".format(row[2], row[3])
        create_table_stmt = create_table_stmt.rstrip(",\n") + "\n);"
        return create_table_stmt
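
    # --- Illustrative sketch (not part of the original module) ------------- #
    # For a hypothetical table sales(id INT, style VARCHAR), this method would
    # return roughly:
    #
    #   CREATE TABLE sales (
    #   id int,
    #   style varchar
    #   );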
    def get_all_table_names(self):
        """
        Get all table names in the database
        """
        get_all_tables_stmt = (
            "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';"
        )
        self.cur.execute(get_all_tables_stmt)
        return [row[0] for row in self.cur.fetchall()]

    def get_table_definitions_for_prompt(self):
        """
        Get all table 'create' definitions in the database
        """
        table_names = self.get_all_table_names()
        definitions = []
        for table_name in table_names:
            definitions.append(self.get_table_definition(table_name))
        return "\n\n".join(definitions)

    def get_table_definition_map_for_embeddings(self):
        """
        Create a map of table names to table definitions
        """
        table_names = self.get_all_table_names()
        definitions = {}
        for table_name in table_names:
            definitions[table_name] = self.get_table_definition(table_name)
        return definitions
    def get_related_tables(self, table_list, n=2):
        """
        For each table, collect tables related through foreign keys:
        tables that reference it, and tables that it references.
        """
        related_tables_dict = {}
        for table in table_list:
            # Tables that have foreign keys referencing the given table
            self.cur.execute(
                """
                SELECT OBJECT_NAME(fk.parent_object_id) AS table_name
                FROM sys.foreign_keys fk
                WHERE fk.referenced_object_id = OBJECT_ID(%s)
                ORDER BY table_name
                OFFSET 0 ROWS FETCH NEXT %s ROWS ONLY;
                """,
                (table, n),
            )
            related_tables = [row[0] for row in self.cur.fetchall()]
            # Tables that the given table references
            self.cur.execute(
                """
                SELECT OBJECT_NAME(fk.referenced_object_id) AS table_name
                FROM sys.foreign_keys fk
                WHERE fk.parent_object_id = OBJECT_ID(%s)
                ORDER BY table_name
                OFFSET 0 ROWS FETCH NEXT %s ROWS ONLY;
                """,
                (table, n),
            )
            related_tables += [row[0] for row in self.cur.fetchall()]
            related_tables_dict[table] = related_tables
        # Flatten the dict into a list and remove duplicates
        related_tables_list = []
        for table, related_tables in related_tables_dict.items():
            related_tables_list += related_tables
        related_tables_list = list(set(related_tables_list))
        return related_tables_list
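
# --- Illustrative sketch (not part of the original module) ----------------- #
# Standalone use of the manager, assuming 'db_params' is a pymssql-compatible
# dict (server, user, password, database). Commented out because it needs a
# live database.
#
# with PostgresManager() as db:  # __exit__ closes the cursor and connection
#     db.connect_with_url(db_params)
#     print(db.get_all_table_names())
#     print(db.run_sql("SELECT 1 AS probe;"))
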
# Embeds database table definitions and scores the similarity between a user
# query and each table definition.
class DatabaseEmbedder:
    """
    This class is responsible for embedding database table definitions and
    computing similarity between user queries and table definitions.
    """

    def __init__(self, db: PostgresManager, rerank: bool):
        if rerank:
            # Cross-encoder reranker (a BGE reranker checkpoint)
            self.model = AutoModelForSequenceClassification.from_pretrained(bge_model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(bge_model_path)
        else:
            # Bi-encoder embeddings via BERT
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", local_files_only=True)
            self.model = BertModel.from_pretrained("bert-base-uncased", local_files_only=True)
        self.map_name_to_embeddings = {}
        self.map_name_to_table_def = {}
        self.db = db
        self.rerank = rerank
    def get_similar_table_defs_for_prompt(self, prompt: str, n_similar=5, n_foreign=0):
        map_table_name_to_table_def = self.db.get_table_definition_map_for_embeddings()
        for name, table_def in map_table_name_to_table_def.items():
            self.add_table(name, table_def)
        similar_tables = self.get_similar_tables(prompt, n=n_similar)
        table_definitions = self.get_table_definitions_from_names(similar_tables)
        if n_foreign > 0:
            foreign_table_names = self.db.get_related_tables(similar_tables, n=3)
            table_definitions = self.get_table_definitions_from_names(
                foreign_table_names + similar_tables
            )
        return table_definitions
    def add_table(self, table_name: str, text_representation: str):
        """
        Add a table to the database embedder.
        Map the table name to its embedding and text representation.
        """
        if self.rerank:
            # The cross-encoder scores (query, table) pairs on demand,
            # so no precomputed embedding is needed.
            self.map_name_to_embeddings[table_name] = None
        else:
            self.map_name_to_embeddings[table_name] = self.compute_embeddings(
                text_representation
            )
        self.map_name_to_table_def[table_name] = text_representation

    def compute_embeddings(self, text):
        """
        Compute embeddings for a given text using the BERT model.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=512
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs["pooler_output"].detach().numpy()
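
    # --- Illustrative sketch (not part of the original module) ------------- #
    # The pooler output is a (1, hidden_size) array, so two embeddings can be
    # compared directly with sklearn's cosine_similarity:
    #
    # emb_a = embedder.compute_embeddings("CREATE TABLE sales (...);")
    # emb_b = embedder.compute_embeddings("monthly sales by style")
    # score = cosine_similarity(emb_a, emb_b)[0][0]  # scalar similarity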
    def get_similar_tables_via_rerank(self, query, n=5):
        # Score every table with the cross-encoder: the table *name* and the
        # table *definition* are scored separately and blended 0.7 / 0.3.
        self.model.eval()
        with torch.no_grad():
            result = {}
            for tab, tab_def in self.map_name_to_table_def.items():
                inputs_1 = self.tokenizer(
                    [[query, tab]], padding=True, truncation=True,
                    return_tensors='pt', max_length=512,
                )
                scores_1 = self.model(**inputs_1, return_dict=True).logits.view(-1).float()[0]
                inputs_2 = self.tokenizer(
                    [[query, tab_def]], padding=True, truncation=True,
                    return_tensors='pt', max_length=512,
                )
                scores_2 = self.model(**inputs_2, return_dict=True).logits.view(-1).float()[0]
                score = 0.7 * scores_1 + 0.3 * scores_2
                probs = torch.sigmoid(score)
                result[tab] = probs
        print(f'similarity : {result}')
        sorted_results = sorted(result.items(), key=lambda x: x[1], reverse=True)
        final_result = [x[0] for x in sorted_results]
        return final_result[:n]
    def get_similar_tables_via_embeddings(self, query, n=3):
        """
        Given a query, find the top 'n' tables that are most similar to it.

        Args:
        - query (str): The user's natural language query.
        - n (int, optional): Number of top tables to return. Defaults to 3.

        Returns:
        - list: Top 'n' table names ranked by their similarity to the query.
        """
        # Compute the embedding for the user's query
        query_embedding = self.compute_embeddings(query)
        # Calculate cosine similarity between the query and all tables
        similarities = {
            table: cosine_similarity(query_embedding, emb)[0][0]
            for table, emb in self.map_name_to_embeddings.items()
        }
        # Rank tables by similarity score and return the top 'n'
        return sorted(similarities, key=similarities.get, reverse=True)[:n]

    def get_similar_table_names_via_word_match(self, query: str):
        """
        If any word in the query is a table name, add that table to the list.
        """
        tables = []
        for table_name in self.map_name_to_table_def.keys():
            if table_name.lower() in query.lower():
                tables.append(table_name)
        return tables

    def get_similar_tables(self, query: str, n=3):
        """
        Combine results from rerank/embedding retrieval and literal word match.
        """
        if self.rerank:
            similar_tables_via_embeddings = self.get_similar_tables_via_rerank(query, n)
        else:
            similar_tables_via_embeddings = self.get_similar_tables_via_embeddings(query, n)
        similar_tables_via_word_match = self.get_similar_table_names_via_word_match(query)
        # Deduplicate while preserving order
        temp_list = similar_tables_via_embeddings + similar_tables_via_word_match
        unique_list = list(dict.fromkeys(temp_list))
        return unique_list

    def get_table_definitions_from_names(self, table_names: list) -> str:
        """
        Given a list of table names, return their table definitions.
        """
        table_defs = [
            self.map_name_to_table_def[table_name] for table_name in table_names
        ]
        return "\n\n".join(table_defs)
# Orchestrates the main SQL-analysis flow: schema retrieval, SQL generation,
# execution with retries, and answer summarization.
class sql_analyze_father:
    def __init__(self, data_engineer: autogen.AssistantAgent, client_id: str,
                 db_param: dict, table_name=None) -> None:
        self.sql_generator = data_engineer
        self.db_param = db_param
        self.client_id = client_id
        # Avoid a mutable default argument for table_name
        self.table_name = table_name or []

    def get_sql(self, content):
        """
        Extract the SQL statement from the agent's reply.
        """
        sql = content['content']
        if sql.startswith("SQL query:\n"):
            # Split only on the first colon so colons inside the SQL survive
            return sql.split(':', 1)[1].strip()
        elif '```' in sql:
            # Take the first fenced block and drop an optional 'sql' language tag
            block = sql.split('```')[1]
            if block.startswith('sql'):
                block = block[3:]
            return block.strip()
        else:
            return sql
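
    # --- Illustrative sketch (not part of the original module) ------------- #
    # How get_sql handles the three reply shapes it expects:
    #
    # get_sql({'content': 'SQL query:\nSELECT 1;'})   -> 'SELECT 1;'
    # get_sql({'content': '```sql\nSELECT 1;\n```'})  -> 'SELECT 1;'
    # get_sql({'content': 'SELECT 1;'})               -> 'SELECT 1;'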

    def add_cap_ref(
        self, prompt: str, prompt_suffix: str, cap_ref: str,
        cap_ref_content: str, note: str, example: str,
    ) -> str:
        """
        Append a capitalized reference section, plus note and example, to the prompt.
        """
        new_prompt = f"""{prompt} {prompt_suffix}\n\n{cap_ref}\n\n{cap_ref_content}\n\n{note}\n\n{example}"""
        return new_prompt
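
    # --- Illustrative sketch (not part of the original module) ------------- #
    # The assembled prompt is laid out as blank-line-separated sections:
    #
    #   <prompt> <prompt_suffix>
    #   <cap_ref>           e.g. POSTGRES_TABLE_DEFINITIONS_CAP_REF
    #   <cap_ref_content>   the retrieved CREATE TABLE definitions
    #   <note>              NOTE imported from prompt
    #   <example>           EXAMPLE imported from prompt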

    async def run_sql_analyze(self, raw_prompt):
        with PostgresAgentInstruments(self.db_param, self.client_id) as (agent_instruments, db):
            # Embed all table definitions, then pick the tables relevant to the prompt
            map_table_name_to_table_def = db.get_table_definition_map_for_embeddings()
            database_embedder = DatabaseEmbedder(db, rerank=True)
            for name, table_def in map_table_name_to_table_def.items():
                database_embedder.add_table(name, table_def)
            if not self.table_name:
                similar_tables = database_embedder.get_similar_tables(raw_prompt, n=5)
                print(f'similar_tables {similar_tables}')
                table_definitions = database_embedder.get_table_definitions_from_names(
                    similar_tables
                )
            else:
                table_definitions = database_embedder.get_table_definitions_from_names(
                    self.table_name
                )
            prompt = f"Please meet the needs of the user: {raw_prompt}, "
            prompt = self.add_cap_ref(
                prompt,
                f"and use these {POSTGRES_TABLE_DEFINITIONS_CAP_REF} to satisfy the database query. Please ensure that the SQL is as efficient as possible and conforms to the syntax of the database.",
                POSTGRES_TABLE_DEFINITIONS_CAP_REF,
                table_definitions, NOTE, EXAMPLE,
            )
            messages = [{'role': 'user', 'content': prompt}]
            results = '[]'

            # Up to three attempts: regenerate the SQL while the result set is
            # empty or the database reported an error.
            i = 0
            try:
                while i < 3 and (len(results) == 0 or results == '[]' or 'Error occurred' in results):
                    sql_reply = await self.sql_generator.a_generate_reply(messages=messages)
                    sql_reply = sql_reply if isinstance(sql_reply, dict) else {'role': 'assistant', 'content': sql_reply}
                    sql = self.get_sql(sql_reply)
                    if 'I dont know' in sql:
                        i += 1
                        continue
                    # Record the generated SQL and the execution feedback
                    messages.append({'role': 'assistant', 'content': sql})
                    results = db.run_sql(sql)
                    messages.append({'role': 'user', 'content': results})
                    i += 1
                print(f'messages before *****{messages}')
                if i == 3 and (len(results) == 0 or results == '[]' or 'Error occurred' in results):
                    # All three attempts failed: drop the retry exchanges and
                    # record a final failure message.
                    del messages[-6:]
                    if 'I dont know' in sql:
                        messages.append({'role': 'assistant', 'content': 'The question is not related closely enough to the provided table information, so I could not retrieve the relevant data.'})
                    else:
                        messages.append({'role': 'assistant', 'content': f'There was a problem generating the SQL; the result was: {results}'})
                else:
                    # Keep only the prompt and the last exchange
                    del messages[-2 * i:-2]
                print('\n ---------------- \n')
                print(f'messages after *****{messages}')

            except Exception as e:
                print(e)
            # Summarize the final data for the user
            data_sql = messages[-1].get('content')
            summary_messages = [
                {'role': 'user', 'content': raw_prompt},
                {'role': 'assistant', 'content': f'Generated SQL: \n {sql} \n Execution results: {data_sql}'},
            ]
            print(summary_messages)
            summary = await sql_answer.a_generate_reply(messages=summary_messages)
            summary_content = summary['content'] if isinstance(summary, dict) else summary
            print(f'final_answer: \n {summary_content}\n')
            return sql, results, summary_content

    def make_data(self, input_data, excel_file, plot_file):
        # Export the JSON results to Excel, build a DataFrame, and auto-plot it
        json_to_excel(input_data, excel_file)
        df = json_to_dataframe(input_data)
        print(df)
        plot_data_html = plot_instance.auto_plot(df, plot_file)
        return df, plot_file


if __name__ == '__main__':
    db_param = get_db_param('sale_database')
    sql_instance = sql_analyze_father(data_engineer=data_engineer, client_id='dalin', db_param=db_param)
    # Example question: "Which clothing styles are the most popular in sales?"
    sql, results, summary = asyncio.run(sql_instance.run_sql_analyze(raw_prompt='哪些服装款式在销售中最受欢迎'))
    df, plot_file = sql_instance.make_data(results, './xxx.xlsx', './yyyy.html')
    print(sql, results, summary)