# server.py (15 KB)


  1. # basic import
  2. import uvicorn, json, os, uuid, docker, pymssql, autogen
  3. from autogen import ConversableAgent
  4. from copy import deepcopy
  5. from config import llm_config, file_url, BASE_UPLOAD_DIRECTORY, STATIC_DIR, db_list_2, db_cn_map, db_en_map
  6. # api server import
  7. from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  8. from fastapi.staticfiles import StaticFiles
  9. from fastapi import FastAPI, File, UploadFile, Form
  10. from fastapi.responses import JSONResponse
  11. from fastapi.middleware.cors import CORSMiddleware
  12. # functioncall import
  13. from agents import data_engineer, detect_analyze_agent
  14. from tools import validate_use_tools, generate_result
  15. # sql import
  16. from sql_instruments import sql_analyze_father
  17. from util import get_timestamp, get_db_param
  18. # code import
  19. from code_instruments import code_analyze_father
  20. # asr import
  21. from config import stt_model, port
  22. from funasr.utils.postprocess_utils import rich_transcription_postprocess
  # Shared Docker client; the WebSocket handler uses it to stop and remove
  # the Jupyter container that belongs to a disconnecting client.
  docker_client = docker.from_env()
  app = FastAPI()

  # Create both directories idempotently. A separate exists() check followed
  # by makedirs() can race when several server processes start together.
  os.makedirs(BASE_UPLOAD_DIRECTORY, exist_ok=True)
  os.makedirs(STATIC_DIR, exist_ok=True)

  # Everything stored under BASE_UPLOAD_DIRECTORY is served at the STATIC_DIR
  # URL prefix.
  app.mount(STATIC_DIR, StaticFiles(directory=BASE_UPLOAD_DIRECTORY), name="static")

  # NOTE(review): the FastAPI CORS documentation says wildcard origins must not
  # be combined with allow_credentials=True; list the real front-end origins
  # here once they are known.
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Per-client state, keyed by client_id.
  code_instance_map = {}  # client_id -> code_analyze instance
  sql_instance_map = {}   # client_id -> sql_analyze instance
  status_map = {}         # client_id -> 'normal' | 'code' | 'sql'
  class sql_analyze(sql_analyze_father):
      """SQL analysis session that reports its results over a WebSocket.

      Extends ``sql_analyze_father`` with the URL path prefix under which
      generated result files are published, and with a helper that runs one
      analysis and sends the outcome to the client.
      """

      def __init__(self, data_engineer: autogen.AssistantAgent, client_id: str, db_param: dict, table_name=None, file_path=STATIC_DIR) -> None:
          """Initialise the session.

          Args:
              data_engineer: Agent forwarded to the base class.
              client_id: Identifier of the client; also used as the name of
                  the client's directory under ``BASE_UPLOAD_DIRECTORY``.
              db_param: Database parameters forwarded to the base class.
              table_name: Tables forwarded to the base class. ``None``
                  (the default) means an empty list.
              file_path: URL path prefix placed in front of the client
                  directory when building download links.
          """
          # Build a fresh list per instance so no list object is shared
          # between sessions through the default argument.
          tables = [] if table_name is None else table_name
          super().__init__(data_engineer, client_id, db_param, tables)
          self.file_path = file_path

      async def ws_send_data(self, ws: WebSocket, raw_prompt, file_url):
          """Run one SQL analysis and send the result frames to ``ws``.

          Frames sent, in order: the summary text with no files, then (only
          when the query produced usable rows) a frame with the links to the
          Excel and plot files, and finally the plain-text marker ``'end'``.

          Args:
              ws: Accepted WebSocket connection to write to.
              raw_prompt: User question passed to ``run_sql_analyze``.
              file_url: Base URL prepended to every generated file link.
          """
          self.upload_file_path = os.path.join(BASE_UPLOAD_DIRECTORY, self.client_id)
          if not os.path.exists(self.upload_file_path):
              os.makedirs(self.upload_file_path, exist_ok=True)
              # World-writable so other processes can write into the directory.
              os.chmod(self.upload_file_path, 0o777)

          # One timestamp for both artefacts so the pair shares a prefix.
          stamp = get_timestamp()
          self.run_sql_result_file_excel = os.path.join(self.upload_file_path, stamp + '_result.xlsx')
          self.plot_data_file = os.path.join(self.upload_file_path, stamp + '_result_plot.html')

          _sql, results, summary_content = await self.run_sql_analyze(raw_prompt)

          no_usable_rows = (
              results is None
              or len(results) == 0
              or results == '[]'
              or 'Error occurred' in results
          )
          data = []
          if not no_usable_rows:
              # Generate the files before sending anything, so a failure here
              # does not leave the client with a half-delivered answer.
              self.make_data(results, self.run_sql_result_file_excel, self.plot_data_file)
              link_prefix = f'{file_url}' + self.file_path + '/' + self.client_id + '/'
              data = [
                  link_prefix + os.path.basename(self.run_sql_result_file_excel),
                  link_prefix + os.path.basename(self.plot_data_file),
              ]
              print(f'data_send: \n {data}\n')

          await ws.send_json({'text': summary_content, 'files': []})
          if data:
              await ws.send_json({'text': '', 'files': data})
          await ws.send_text('end')
  class code_analyze(code_analyze_father):
      """Code-execution analysis session that reports over a WebSocket."""

      def __init__(self, client_id: str, history: list, files: list) -> None:
          """Pass the client id, chat history and file list to the base class."""
          super().__init__(client_id, history, files)

      async def ws_send_data(self, ws: WebSocket, prompt):
          """Run the analysis for ``prompt`` and send the result frames.

          Frames sent, in order: the summary text with no files, the produced
          files with empty text, then the plain-text marker ``'end'``.
          """
          outcome = await self.run_seperate_jupyter_auto(prompt=prompt)
          summary_text, produced_files = outcome
          frames = (
              {'text': summary_text, 'files': []},
              {'text': '', 'files': produced_files},
          )
          for frame in frames:
              await ws.send_json(frame)
          await ws.send_text('end')
  async def _send_final_reply(ws: WebSocket, text, files=None) -> None:
      """Send one ``{'text', 'files'}`` JSON frame followed by the ``'end'`` marker."""
      await ws.send_json({'text': text, 'files': [] if files is None else files})
      await ws.send_text('end')


  def _release_client_resources(client_id: str, stop_executor: bool = False) -> None:
      """Forget all per-client state and remove the client's Jupyter container.

      Every step is best-effort: a failure is printed and the remaining steps
      still run, so one broken container cannot leave stale map entries behind.

      Args:
          client_id: Client whose state is released.
          stop_executor: Also call ``executor.stop()`` on the code session
              before the container is removed.
      """
      code_instance = code_instance_map.pop(client_id, None)
      if code_instance is not None:
          if stop_executor:
              try:
                  code_instance.executor.stop()
              except Exception as e:
                  print(f'failed to stop executor for {client_id}: {e}')
          try:
              container = docker_client.containers.get(code_instance.jupyter_server._container_id)
              container.stop()
              container.remove()
          except Exception as e:
              print(f'failed to remove container for {client_id}: {e}')
      sql_instance_map.pop(client_id, None)
      status_map.pop(client_id, None)


  @app.websocket("/ws/{client_id}")
  async def websocket_endpoint(ws: WebSocket, client_id: str):
      """Serve one client's chat session over a WebSocket.

      Incoming frames:
          * text  - a JSON object; the keys read are ``prompt``, ``history``,
            ``file_names``, ``db_param`` and ``tables``.
          * bytes - audio that is transcribed with ``stt_model`` and used as
            the prompt; the other fields come from the most recent JSON frame.

      For every request the handler first lets ``validate_use_tools`` decide
      whether a tool call can answer it. Otherwise the request is routed by
      the client's status (``'code'``, ``'sql'`` or ``'normal'``). Every
      answer ends with the plain-text frame ``'end'``.

      All state kept for ``client_id`` is released when the connection ends.
      """
      await ws.accept()
      # Most recent JSON payload. It starts empty so that an audio frame
      # arriving before any JSON frame does not hit an unbound variable.
      message_use = {}
      peer_disconnected = False
      try:
          while True:
              continue_exe = False
              message_origin = await ws.receive()
              if message_origin.get('type') == 'websocket.disconnect':
                  break

              raw_text = message_origin.get('text')
              audio_data = message_origin.get('bytes')
              if not raw_text and not audio_data:
                  # Empty frame: nothing to answer. Without this guard the
                  # previous request would be executed a second time.
                  continue

              if raw_text:
                  try:
                      parsed = json.loads(raw_text)
                      if not isinstance(parsed, dict):
                          raise ValueError('message must be a JSON object')
                  except ValueError as e:  # includes json.JSONDecodeError
                      await _send_final_reply(ws, f'出错啦,遇到了一些问题:{e}')
                      continue
                  message_use = parsed

              if audio_data:
                  res = stt_model.generate(
                      input=audio_data,
                      cache={},
                      language="auto",
                      use_itn=True,
                      batch_size_s=60,
                      merge_vad=True,
                      merge_length_s=15,
                  )
                  prompt = rich_transcription_postprocess(res[0]["text"])
              else:
                  prompt = message_use.get('prompt', '') or ''
              print(f'用户问题:{prompt}')

              history = message_use.get('history', []) or []
              print(f'之前对话历史: {history}')
              # File attachments of earlier turns are not forwarded to the agents.
              for item in history:
                  if 'files' in item:
                      del item['files']

              raw_file_names = message_use.get('file_names', []) or []
              file_names = raw_file_names
              requested_tables = message_use.get('tables', []) or []

              # Translate the client-facing database and table names through
              # db_en_map. An unknown database is handled like "no database"
              # instead of failing on None.
              db_param_temp = message_use.get('db_param', '')
              db_entry = db_en_map.get(db_param_temp) if db_param_temp else None
              if db_entry:
                  db_param = db_entry.get('name') or ''
                  table_map = db_entry.get('table_cn_map') or {}
                  table_name = [table_map.get(i) for i in requested_tables]
              else:
                  if db_param_temp:
                      print(f'unknown db_param ignored: {db_param_temp}')
                  db_param = ''
                  table_name = requested_tables

              # Pick the session mode. With no hint in this message the mode
              # of the previous message is kept.
              if file_names and db_param:
                  status_map[client_id] = 'normal'
              elif file_names:
                  file_names = [f'./upload/{i}' for i in file_names]
                  status_map[client_id] = 'code'
              elif db_param:
                  status_map[client_id] = 'sql'
              elif not status_map.get(client_id, None):
                  status_map[client_id] = 'normal'

              changeable_prompt = prompt + (f' tables_name:{table_name}' if table_name else '') + (f' files:{file_names}' if file_names else '')
              use_tools = await validate_use_tools(changeable_prompt)
              if use_tools:
                  client_dir = os.path.join(BASE_UPLOAD_DIRECTORY, client_id)
                  # Plain concatenation on purpose: os.path.join would drop
                  # the 'upload' prefix if a client sent an absolute name.
                  func_file_names = [os.path.join(client_dir, 'upload/' + s) for s in raw_file_names]
                  new_prompt = f"{prompt} file_path:{func_file_names}"
                  function_answer, data_result = generate_result(prompt=new_prompt)
                  print(data_result)
                  if data_result:
                      max_inline_size = 15 * 1024 * 1024
                      data_all_result = [f'{file_url}{x}' for x in data_result]
                      data_big_result = [
                          f'{file_url}{x}' for x in data_result
                          if os.path.getsize(x.replace(STATIC_DIR, BASE_UPLOAD_DIRECTORY)) > max_inline_size
                      ]
                      # Large files are offered as download links in the
                      # text; the rest go in 'files', de-duplicated and in
                      # their original order.
                      big_links = set(data_big_result)
                      final_result = [u for u in dict.fromkeys(data_all_result) if u not in big_links]
                      print(f'function data_send: \n {final_result}\n')
                      if data_big_result:
                          final_answer = function_answer + '\n 下载链接: \n' + ','.join(data_big_result)
                      else:
                          final_answer = function_answer
                      await _send_final_reply(ws, final_answer, final_result)
                  else:
                      # The tool produced nothing: use the regular pipeline.
                      continue_exe = True
              else:
                  continue_exe = True

              if not continue_exe:
                  continue

              print(f'继续执行: {continue_exe}')
              analyze_detect = await detect_analyze_agent.a_generate_reply(messages=[{'role': 'user', 'content': prompt}])
              if isinstance(analyze_detect, str):
                  user_content = analyze_detect
              else:
                  # The agent may return a dict or nothing at all.
                  user_content = (analyze_detect or {}).get('content') or ''
              print(f'数据分析意图识别:{user_content} \n')
              if '否' in user_content:
                  status_map[client_id] = 'normal'

              mode = status_map.get(client_id)
              if mode == 'code':
                  try:
                      print('code model')
                      code_session = code_instance_map.get(client_id)
                      if code_session is None:
                          code_session = code_analyze(client_id=client_id, history=history, files=file_names)
                          code_instance_map[client_id] = code_session
                      else:
                          # Reuse the running session with the latest inputs.
                          code_session.files = file_names
                          code_session.history = history
                      await code_session.ws_send_data(ws=ws, prompt=prompt)
                  except WebSocketDisconnect:
                      raise
                  except Exception as e:
                      await _send_final_reply(ws, f'出错啦,遇到了一些问题:{e}')
              elif mode == 'sql':
                  # Same error contract as the code branch: report the
                  # problem and keep the connection open.
                  try:
                      print('sql model')
                      db_param = get_db_param(db_param)
                      sql_session = sql_instance_map.get(client_id)
                      if sql_session is None:
                          sql_session = sql_analyze(data_engineer=data_engineer, client_id=client_id, db_param=db_param, table_name=table_name, file_path=STATIC_DIR)
                          sql_instance_map[client_id] = sql_session
                      else:
                          if db_param:
                              sql_session.db_param = db_param
                          if table_name:
                              sql_session.table_name = table_name
                      await sql_session.ws_send_data(ws=ws, raw_prompt=prompt, file_url=file_url)
                  except WebSocketDisconnect:
                      raise
                  except Exception as e:
                      await _send_final_reply(ws, f'出错啦,遇到了一些问题:{e}')
              elif mode == 'normal':
                  print('normal model')
                  normal_answer = ConversableAgent(
                      'normal',
                      system_message="You are a helpful assistant to answer the user's question as best as you can.",
                      llm_config=llm_config)
                  messages = deepcopy(history)
                  messages.append({'role': 'user', 'content': prompt})
                  answer = await normal_answer.a_generate_reply(messages=messages)
                  answer = answer.get('content', '很抱歉我可能无法回答你的问题') if isinstance(answer, dict) else answer
                  print(f'final_answer: \n {answer}\n')
                  await _send_final_reply(ws, answer)
      except WebSocketDisconnect:
          print('websocket正常断开')
          # Only record the fact here. The actual cleanup happens once, in
          # ``finally``, so the container is removed on this path as well.
          peer_disconnected = True
      finally:
          _release_client_resources(client_id, stop_executor=peer_disconnected)
          print(f"Client {client_id} disconnected")
          print(f'The map after removed: \ncode map:{code_instance_map} \nsql map: {sql_instance_map}, \nstatus map: {status_map}')
  @app.post("/uploadfile/")
  async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
      """Store an uploaded file in the client's ``upload`` directory.

      The file is written to ``BASE_UPLOAD_DIRECTORY/<client_id>/upload/<name>``.

      Args:
          file: Uploaded file (multipart form field).
          client_id: Identifier of the uploading client (form field).

      Returns:
          ``JSONResponse`` with status 200 and the stored path on success,
          400 when the file name or ``client_id`` is not a plain path
          component, or 500 when writing fails.
      """
      # Both values are client-controlled. Accept a single path component
      # only, so "../" or absolute paths cannot escape the upload tree.
      safe_name = os.path.basename((file.filename or '').replace('\\', '/'))
      client_component = os.path.basename(client_id.replace('\\', '/'))
      reserved = ('', '.', '..')
      if safe_name in reserved or client_id in reserved or client_component != client_id:
          return JSONResponse(content={"message": "发生错误: invalid file name or client_id"}, status_code=400)

      temp_directory = os.path.join(BASE_UPLOAD_DIRECTORY, client_id)
      user_directory = os.path.join(temp_directory, 'upload')
      file_location = os.path.join(user_directory, safe_name)
      try:
          if not os.path.exists(user_directory):
              os.makedirs(user_directory, exist_ok=True)
              # World-writable so other processes can use the directory.
              os.chmod(user_directory, 0o777)
          with open(file_location, "wb") as file_object:
              # Copy in 1 MiB chunks instead of loading the whole upload
              # into memory with one blocking read.
              while True:
                  chunk = await file.read(1024 * 1024)
                  if not chunk:
                      break
                  file_object.write(chunk)
          os.chmod(file_location, 0o777)
          return JSONResponse(content={
              "message": f"文件 '{file.filename}' 上传成功",
              "client_id": client_id,
              "file_path": file_location
          }, status_code=200)
      except Exception as e:
          return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
  @app.post("/database")
  async def get_table_name():
      """List every configured database with its tables, using display names.

      Each entry of ``db_list_2`` is connected to with ``pymssql`` and its
      base tables are read from ``information_schema.tables``. Database and
      table names are translated through ``db_cn_map`` before being returned.

      Returns:
          ``JSONResponse`` with status 200 and ``{"databaseInfo": [...]}``,
          each item holding ``db_param`` and ``tables``; status 500 with an
          error message if any step fails.
      """
      database = []
      try:
          for db in db_list_2:
              conn = pymssql.connect(**db)
              try:
                  cursor = conn.cursor()
                  try:
                      cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE'")
                      table_list = [row[0] for row in cursor.fetchall()]
                  finally:
                      cursor.close()
              finally:
                  # Always release the connection, also when the query
                  # fails; previously every request leaked one per database.
                  conn.close()

              db_info = db_cn_map.get(db['database'])
              table_map = db_info.get('table_cn_map')
              database.append({
                  'db_param': db_info.get('name'),
                  'tables': [table_map.get(i) for i in table_list],
              })
          return JSONResponse(content={"databaseInfo": database}, status_code=200)
      except Exception as e:
          return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
  @app.post("/get_client_id")
  async def get_table_name():
      """Return a freshly generated random (UUID4) client id.

      Responds with status 200 and ``{"client_id": <uuid>}``, or with status
      500 and an error message if generating the id fails.
      """
      try:
          new_id = uuid.uuid4()
          body = {"client_id": str(new_id)}
          response = JSONResponse(content=body, status_code=200)
      except Exception as err:
          response = JSONResponse(content={"message": f"发生错误: {str(err)}"}, status_code=500)
      return response
  if __name__ == "__main__":
      # Script entry point: serve the app on all interfaces at the
      # configured port.
      import uvicorn
      server_options = {"host": "0.0.0.0", "port": port}
      uvicorn.run(app, **server_options)