"""FastAPI backend: a chat websocket plus file-upload, database-listing and
client-id endpoints.

For every websocket message the server:
1. offers the prompt to the tool-calling path (``validate_use_tools`` /
   ``generate_result``);
2. builds the fixed retail-franchise report when the prompt equals
   ``REPORT_PROMPT``;
3. if neither step answered, routes the prompt by the client's mode in
   ``status_map``: ``'code'`` -> ``code_analyze``, ``'sql'`` -> ``sql_analyze``,
   ``'normal'`` -> a one-off ``ConversableAgent`` reply.
Every reply is terminated by the text frame ``'end'``. Per-client sessions and
modes live in module-level dicts and are dropped when the socket closes.
"""
# basic import
import json
import os
import shutil
import traceback
import uuid
from copy import deepcopy
from datetime import datetime
from typing import List, Optional, Tuple

import autogen
import docker
import pymssql
import uvicorn
from autogen import ConversableAgent

# NOTE: project-local imports keep their original relative order in case a
# module depends on side effects of an earlier one (e.g. config).
from config import llm_config, file_url, BASE_UPLOAD_DIRECTORY, STATIC_DIR, db_list_2, db_cn_map, db_en_map

# api server import
from fastapi import FastAPI, File, Form, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles

# functioncall import
from agents import data_engineer, detect_analyze_agent
from tools import validate_use_tools, generate_result
from data_processor.data_processor import data_processor

# sql import
from sql_instruments import sql_analyze_father
from util import get_timestamp, get_db_param

# code import
from code_instruments import code_analyze_father

# asr import
from config import stt_model, port
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# Tool outputs larger than this are sent as download links inside the text
# instead of in the 'files' list.
LARGE_FILE_BYTES = 15 * 1024 * 1024
# Prompt that triggers the fixed retail-franchise large-order report.
REPORT_PROMPT = '生成零售加盟大单报表'
# Text frame that terminates every reply on the websocket.
END_MARKER = 'end'

docker_client = docker.from_env()
app = FastAPI()

os.makedirs(BASE_UPLOAD_DIRECTORY, exist_ok=True)
os.makedirs(STATIC_DIR, exist_ok=True)
# Files under BASE_UPLOAD_DIRECTORY are served at the URL path STATIC_DIR.
# NOTE(review): STATIC_DIR is used both as a URL mount path and as a local
# directory created above -- confirm both uses are intended.
app.mount(STATIC_DIR, StaticFiles(directory=BASE_UPLOAD_DIRECTORY), name="static")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Per-client state keyed by client_id; entries are removed when the socket closes.
code_instance_map = {}  # client_id -> code_analyze session
sql_instance_map = {}  # client_id -> sql_analyze session
status_map = {}  # client_id -> current mode: 'normal' | 'code' | 'sql'


class sql_analyze(sql_analyze_father):
    """SQL-analysis session that sends its results over a websocket."""

    def __init__(self, data_engineer: autogen.AssistantAgent, client_id: str, db_param: dict,
                 table_name: Optional[list] = None, file_path=STATIC_DIR) -> None:
        """Create a session.

        Args:
            data_engineer: agent handed to ``sql_analyze_father``.
            client_id: id of the websocket client; also names its result directory.
            db_param: database parameters as returned by ``get_db_param``.
            table_name: tables to analyse; ``None`` means an empty list. A new
                list is created per instance instead of a shared mutable default.
            file_path: URL path prefix under which result files are published.
        """
        super().__init__(data_engineer, client_id, db_param, [] if table_name is None else table_name)
        self.file_path = file_path

    async def ws_send_data(self, ws: WebSocket, raw_prompt, file_url):
        """Run the SQL analysis for ``raw_prompt`` and send the outcome on ``ws``.

        On an empty or error result only the summary is sent. Otherwise the
        result is written to an .xlsx and a plot .html file (via ``make_data``).
        The summary is sent first, then a second message listing both download
        URLs, then ``'end'``.
        """
        self.upload_file_path = os.path.join(BASE_UPLOAD_DIRECTORY, self.client_id)
        if not os.path.exists(self.upload_file_path):
            os.makedirs(self.upload_file_path, exist_ok=True)
            # World-writable, presumably so other processes can write here too.
            # NOTE(review): 0o777 is very permissive -- confirm it is required.
            os.chmod(self.upload_file_path, 0o777)
        # One timestamp so the two result files of a query share a prefix.
        timestamp = get_timestamp()
        self.run_sql_result_file_excel = os.path.join(self.upload_file_path, timestamp + '_result.xlsx')
        self.plot_data_file = os.path.join(self.upload_file_path, timestamp + '_result_plot.html')

        _sql, results, summary_content = await self.run_sql_analyze(raw_prompt)
        if len(results) == 0 or results == '[]' or 'Error occurred' in results:
            await ws.send_json({'text': summary_content, 'files': []})
            await ws.send_text(END_MARKER)
            return

        self.make_data(results, self.run_sql_result_file_excel, self.plot_data_file)
        url_prefix = f'{file_url}{self.file_path}/{self.client_id}/'
        data = [url_prefix + os.path.basename(path)
                for path in (self.run_sql_result_file_excel, self.plot_data_file)]
        print(f'data_send: \n {data}\n')
        await ws.send_json({'text': summary_content, 'files': []})
        await ws.send_json({'text': '', 'files': data})
        await ws.send_text(END_MARKER)


class code_analyze(code_analyze_father):
    """Code-analysis session that sends its results over a websocket."""

    def __init__(self, client_id: str, history: list, files: list) -> None:
        super().__init__(client_id, history, files)

    async def ws_send_data(self, ws: WebSocket, prompt):
        """Run ``prompt`` through ``run_seperate_jupyter_auto`` and send the outcome.

        The summary text is sent first, then the produced files, then ``'end'``.
        """
        summary, data_all_send = await self.run_seperate_jupyter_auto(prompt=prompt)
        await ws.send_json({'text': summary, 'files': []})
        await ws.send_json({'text': '', 'files': data_all_send})
        await ws.send_text(END_MARKER)


def _transcribe(audio_data: bytes) -> str:
    """Convert an audio frame to prompt text with the configured ``stt_model``."""
    res = stt_model.generate(
        input=audio_data,
        cache={},
        language="auto",
        use_itn=True,
        batch_size_s=60,
        merge_vad=True,
        merge_length_s=15,
    )
    return rich_transcription_postprocess(res[0]["text"])


def _resolve_db_selection(message_use: dict) -> Tuple[object, list]:
    """Map the client's database display name and table names to real names.

    A display name found in ``db_en_map`` is translated, together with its
    tables via ``table_cn_map``. Table names without a mapping are kept as-is
    (they used to become ``None``). An empty or unknown display name is passed
    through unchanged with the raw table list; an unknown name used to raise
    AttributeError.
    """
    db_key = message_use.get('db_param', '')
    tables = message_use.get('tables') or []
    db_info = db_en_map.get(db_key) if db_key else None
    if not db_info:
        return db_key, tables
    table_map = db_info.get('table_cn_map') or {}
    return db_info.get('name') or db_key, [table_map.get(t, t) for t in tables]


def _update_status(client_id: str, file_names: list, db_param) -> list:
    """Update the client's mode from what this message selected.

    Rules:
    - files and a database -> 'normal';
    - files only -> 'code';
    - database only -> 'sql';
    - neither -> keep the current mode, defaulting to 'normal'.

    Returns ``file_names``, prefixed with ``./upload/`` when switching to code
    mode (presumably the path as seen by the code-execution environment --
    TODO confirm).
    """
    if file_names and db_param:
        status_map[client_id] = 'normal'
    elif file_names:
        status_map[client_id] = 'code'
        return [f'./upload/{name}' for name in file_names]
    elif db_param:
        status_map[client_id] = 'sql'
    elif not status_map.get(client_id):
        status_map[client_id] = 'normal'
    return file_names


async def _answer_with_tools(ws: WebSocket, client_id: str, prompt: str, raw_file_names: list) -> bool:
    """Answer via ``generate_result``; return True if a reply was sent.

    Returns False (nothing sent) when the tool produced no files, so the caller
    falls back to the other answer paths. Files above ``LARGE_FILE_BYTES`` are
    listed as links in the text; the others go in 'files', deduplicated and in
    their original order.
    """
    func_file_names = [os.path.join(BASE_UPLOAD_DIRECTORY, client_id, 'upload/' + name)
                       for name in raw_file_names]
    new_prompt = f"{prompt} file_path:{func_file_names}"
    # NOTE(review): generate_result is synchronous and blocks the event loop.
    function_answer, data_result = generate_result(prompt=new_prompt)
    print(data_result)
    if not data_result:
        return False

    big_urls, small_urls = [], []
    for path in dict.fromkeys(data_result):
        url = f'{file_url}{path}'
        # Result paths are under the STATIC_DIR URL prefix; map them back to disk.
        if os.path.getsize(path.replace(STATIC_DIR, BASE_UPLOAD_DIRECTORY)) > LARGE_FILE_BYTES:
            big_urls.append(url)
        else:
            small_urls.append(url)
    print(f'function data_send: \n {small_urls}\n')
    if big_urls:
        final_answer = function_answer + '\n 下载链接: \n' + ','.join(big_urls)
    else:
        final_answer = function_answer
    await ws.send_json({'text': final_answer, 'files': small_urls})
    await ws.send_text(END_MARKER)
    return True


async def _send_report(ws: WebSocket, client_id: str, file_names: list) -> None:
    """Build the retail-franchise report from the first file and send its link.

    Sends the "please upload an excel" reply when no file was given; the
    original indexed ``file_names[0]`` first and raised IndexError instead.
    """
    print(f'文件列表:{file_names}')
    excel_file = file_names[0] if file_names else ''
    if not excel_file:
        await ws.send_json({'text': '请先上传excel表格', 'files': []})
        await ws.send_text(END_MARKER)
        return
    export_file = datetime.now().strftime('%Y%m%d%H%M%S') + '.xlsx'
    user_directory = os.path.join(BASE_UPLOAD_DIRECTORY, client_id, 'upload')
    os.makedirs(user_directory, exist_ok=True)
    file_location = os.path.join(user_directory, export_file)
    print(f'生成零售加盟大单报表文件:{file_location}')
    # NOTE(review): excel_file may carry the './upload/' prefix added for code
    # mode (relative to this process's cwd), and the link below embeds a
    # BASE_UPLOAD_DIRECTORY path rather than the STATIC_DIR mount -- confirm.
    data_processor(excel_file, file_location)
    await ws.send_json({'text': '测试成功', 'files': [f'{file_url}{file_location}']})
    await ws.send_text(END_MARKER)


async def _answer_with_code(ws: WebSocket, client_id: str, prompt: str, history: list, file_names: list) -> None:
    """Reuse or create the client's code session and answer ``prompt`` with it.

    Errors are reported to the client. A disconnect is re-raised rather than
    answered on a closed socket.
    """
    try:
        print('code model')
        session = code_instance_map.get(client_id)
        if session is None:
            session = code_analyze(client_id=client_id, history=history, files=file_names)
            code_instance_map[client_id] = session
        else:
            # NOTE(review): a follow-up message without file_names resets files to [].
            session.files = file_names
            session.history = history
        await session.ws_send_data(ws=ws, prompt=prompt)
    except WebSocketDisconnect:
        raise
    except Exception as e:
        traceback.print_exc()
        await ws.send_json({'text': f'出错啦,遇到了一些问题:{e}', 'files': []})
        await ws.send_text(END_MARKER)


async def _answer_with_sql(ws: WebSocket, client_id: str, prompt: str, db_param, table_name: list) -> None:
    """Reuse or create the client's SQL session and answer ``prompt`` with it.

    Non-empty ``db_param``/``table_name`` override an existing session's values.
    Errors are reported to the client, matching the code path; the traceback
    is still printed for debugging.
    """
    try:
        print('sql model')
        db_param = get_db_param(db_param)
        session = sql_instance_map.get(client_id)
        if session is None:
            session = sql_analyze(data_engineer=data_engineer, client_id=client_id, db_param=db_param,
                                  table_name=table_name, file_path=STATIC_DIR)
            sql_instance_map[client_id] = session
        else:
            if db_param:
                session.db_param = db_param
            if table_name:
                session.table_name = table_name
        await session.ws_send_data(ws=ws, raw_prompt=prompt, file_url=file_url)
    except WebSocketDisconnect:
        raise
    except Exception as e:
        traceback.print_exc()
        await ws.send_json({'text': f'出错啦,遇到了一些问题:{e}', 'files': []})
        await ws.send_text(END_MARKER)


async def _answer_plainly(ws: WebSocket, prompt: str, history: list) -> None:
    """Answer ``prompt`` with a fresh general-purpose agent, given ``history``."""
    print('normal model')
    fallback = '很抱歉我可能无法回答你的问题'
    normal_answer = ConversableAgent(
        'normal',
        system_message="You are a helpful assistant to answer the user's question as best as you can.",
        llm_config=llm_config)
    messages = deepcopy(history)
    messages.append({'role': 'user', 'content': prompt})
    answer = await normal_answer.a_generate_reply(messages=messages)
    if isinstance(answer, dict):
        answer = answer.get('content', fallback)
    if answer is None:  # never send a null text to the client
        answer = fallback
    print(f'final_answer: \n {answer}\n')
    await ws.send_json({'text': answer, 'files': []})
    await ws.send_text(END_MARKER)


async def _answer_by_mode(ws: WebSocket, client_id: str, prompt: str, history: list,
                          file_names: list, db_param, table_name: list) -> None:
    """Route ``prompt`` to the code, sql or normal path based on the client's mode.

    If ``detect_analyze_agent``'s reply contains '否', the client is first
    switched to normal mode; that switch persists for later messages.
    """
    print('继续执行: True')
    analyze_detect = await detect_analyze_agent.a_generate_reply(messages=[{'role': 'user', 'content': prompt}])
    if isinstance(analyze_detect, str):
        user_content = analyze_detect
    else:
        user_content = analyze_detect.get('content') if analyze_detect else None
    print(f'数据分析意图识别:{user_content} \n')
    if '否' in (user_content or ''):
        status_map[client_id] = 'normal'

    mode = status_map.get(client_id)
    if mode == 'code':
        await _answer_with_code(ws, client_id, prompt, history, file_names)
    elif mode == 'sql':
        await _answer_with_sql(ws, client_id, prompt, db_param, table_name)
    elif mode == 'normal':
        await _answer_plainly(ws, prompt, history)


def _remove_code_container(client_id: str, session) -> None:
    """Stop and remove the docker container behind a code session (best effort)."""
    try:
        # _container_id is a private attribute of the session's jupyter_server.
        container = docker_client.containers.get(session.jupyter_server._container_id)
        container.stop()
        container.remove()
    except Exception as exc:  # keep cleaning up the other maps even if docker fails
        print(f'{client_id} failed to remove jupyter container: {exc}')


def _release_client(client_id: str) -> None:
    """Drop every per-client entry and remove the code session's container."""
    code_session = code_instance_map.pop(client_id, None)
    if code_session is not None:
        _remove_code_container(client_id, code_session)
    sql_instance_map.pop(client_id, None)
    status_map.pop(client_id, None)
    print(f"Client {client_id} disconnected")
    print(f'The map after removed: \ncode map:{code_instance_map} \nsql map: {sql_instance_map}, \nstatus map: {status_map}')


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(ws: WebSocket, client_id: str):
    """Chat loop for one client.

    Each frame is either JSON text or audio bytes; audio is transcribed and
    used as the prompt. Per-client state is released when the loop ends.
    """
    await ws.accept()
    # Last parsed JSON frame. Audio frames carry no metadata, so they reuse
    # it; initialising it here avoids a NameError on a leading audio frame.
    # NOTE(review): confirm reusing the previous frame's history/files is intended.
    message_use: dict = {}
    try:
        while True:
            message_origin = await ws.receive()
            if message_origin.get('type') == 'websocket.disconnect':
                break
            if message_origin.get('text'):
                message_use = json.loads(message_origin.get('text'))
            audio_data = message_origin.get('bytes')
            # NOTE(review): transcription is synchronous and blocks the event loop.
            prompt = _transcribe(audio_data) if audio_data else message_use.get('prompt', '')
            print(f'用户问题:{prompt}')

            history = message_use.get('history') or []
            print(f'之前对话历史: {history}')
            for item in history:
                item.pop('files', None)  # file attachments are not sent to the LLM

            db_param, table_name = _resolve_db_selection(message_use)
            raw_file_names = message_use.get('file_names') or []
            file_names = _update_status(client_id, raw_file_names, db_param)

            changeable_prompt = (prompt
                                 + (f' tables_name:{table_name}' if table_name else '')
                                 + (f' files:{file_names}' if file_names else ''))
            answered = False
            if await validate_use_tools(changeable_prompt):
                answered = await _answer_with_tools(ws, client_id, prompt, raw_file_names)

            # NOTE(review): runs even if a tool already replied to this prompt.
            if prompt == REPORT_PROMPT:
                await _send_report(ws, client_id, file_names)
                answered = True

            if not answered:
                await _answer_by_mode(ws, client_id, prompt, history, file_names, db_param, table_name)
    except WebSocketDisconnect:
        print('websocket正常断开')
        code_session = code_instance_map.get(client_id)
        if code_session is not None:
            try:
                code_session.executor.stop()
            except Exception as exc:
                print(f'{client_id} failed to stop the executor: {exc}')
    finally:
        # Single cleanup point; the container is removed on every exit path.
        _release_client(client_id)


def _is_safe_path_component(name: str) -> bool:
    """Return True if ``name`` is a single, non-special path component."""
    return bool(name) and name not in ('.', '..') and os.path.basename(name) == name and '/' not in name


@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
    """Save an uploaded file to ``BASE_UPLOAD_DIRECTORY/<client_id>/upload/``.

    The client-supplied file name is reduced to its base name and
    ``client_id`` must be a single path component, so uploads cannot escape
    the upload directory. Returns 400 for unusable names, 500 on write errors.
    """
    safe_name = os.path.basename(file.filename or '')
    if not _is_safe_path_component(client_id) or not _is_safe_path_component(safe_name):
        return JSONResponse(content={"message": "发生错误: 非法的文件名或client_id"}, status_code=400)

    user_directory = os.path.join(BASE_UPLOAD_DIRECTORY, client_id, 'upload')
    if not os.path.exists(user_directory):
        os.makedirs(user_directory, exist_ok=True)
        # NOTE(review): world-writable upload directory; confirm it is required.
        os.chmod(user_directory, 0o777)
    file_location = os.path.join(user_directory, safe_name)
    try:
        with open(file_location, "wb") as file_object:
            shutil.copyfileobj(file.file, file_object)  # streamed, not read fully into memory
        os.chmod(file_location, 0o777)
        return JSONResponse(content={
            "message": f"文件 '{file.filename}' 上传成功",
            "client_id": client_id,
            "file_path": file_location
        }, status_code=200)
    except Exception as e:
        return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)


@app.post("/database")
def get_table_name():
    """List the base tables of every database in ``db_list_2`` using display names.

    Declared as a plain ``def`` so FastAPI runs the blocking pymssql calls in
    its threadpool. Databases or tables missing from ``db_cn_map`` fall back to
    their raw names. Any failure returns 500 with the error message.
    """
    database = []
    try:
        for db in db_list_2:
            conn = pymssql.connect(**db)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE'")
                table_list = [row[0] for row in cursor.fetchall()]
                cursor.close()
            finally:
                conn.close()
            db_info = db_cn_map.get(db['database']) or {}
            table_map = db_info.get('table_cn_map') or {}
            database.append({
                'db_param': db_info.get('name') or db['database'],
                'tables': [table_map.get(t, t) for t in table_list],
            })
        return JSONResponse(content={"databaseInfo": database}, status_code=200)
    except Exception as e:
        return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)


@app.post("/get_client_id")
async def get_client_id():
    """Return a new random (UUID4) client id."""
    return JSONResponse(content={"client_id": str(uuid.uuid4())}, status_code=200)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=port)