server.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # basic import
  2. from datetime import datetime
  3. import uvicorn, json, os, uuid, docker, pymssql, autogen
  4. from autogen import ConversableAgent
  5. from copy import deepcopy
  6. from config import llm_config, file_url, BASE_UPLOAD_DIRECTORY, STATIC_DIR, db_list_2, db_cn_map, db_en_map
  7. # api server import
  8. from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  9. from fastapi.staticfiles import StaticFiles
  10. from fastapi import FastAPI, File, UploadFile, Form
  11. from fastapi.responses import JSONResponse
  12. from fastapi.middleware.cors import CORSMiddleware
  13. # functioncall import
  14. from agents import data_engineer, detect_analyze_agent
  15. from tools import validate_use_tools, generate_result
  16. from data_processor.data_processor import data_processor
  17. # sql import
  18. from sql_instruments import sql_analyze_father
  19. from util import get_timestamp, get_db_param
  20. # code import
  21. from code_instruments import code_analyze_father
  22. # asr import
  23. from config import stt_model, port
  24. from funasr.utils.postprocess_utils import rich_transcription_postprocess
  25. docker_client = docker.from_env()
  26. app = FastAPI()
  27. if not os.path.exists(BASE_UPLOAD_DIRECTORY):
  28. os.makedirs(BASE_UPLOAD_DIRECTORY)
  29. os.makedirs(STATIC_DIR, exist_ok=True)
  30. app.mount(STATIC_DIR, StaticFiles(directory=BASE_UPLOAD_DIRECTORY), name="static")
  31. app.add_middleware(
  32. CORSMiddleware,
  33. allow_origins=["*"],
  34. allow_credentials=True,
  35. allow_methods=["*"],
  36. allow_headers=["*"],
  37. )
  38. code_instance_map = {}
  39. sql_instance_map = {}
  40. status_map = {}
  41. class sql_analyze(sql_analyze_father):
  42. def __init__(self, data_engineer:autogen.AssistantAgent, client_id: str, db_param: dict, table_name=[], file_path=STATIC_DIR) -> None:
  43. super().__init__(data_engineer, client_id, db_param, table_name)
  44. self.file_path = file_path
  45. async def ws_send_data(self, ws:WebSocket, raw_prompt, file_url):
  46. self.upload_file_path = os.path.join(BASE_UPLOAD_DIRECTORY, self.client_id)
  47. if not os.path.exists(self.upload_file_path):
  48. os.makedirs(self.upload_file_path)
  49. os.chmod(self.upload_file_path, 0o777) # 设置用户目录权限为777
  50. self.run_sql_result_file_excel = os.path.join(self.upload_file_path, get_timestamp() + '_result.xlsx')
  51. self.plot_data_file = os.path.join(self.upload_file_path, get_timestamp() + '_result_plot.html')
  52. sql, results, summary_content = await self.run_sql_analyze(raw_prompt)
  53. if len(results)==0 or results == '[]' or 'Error occurred' in results:
  54. await ws.send_json({'text': summary_content,'files': []})
  55. await ws.send_text('end')
  56. else:
  57. data = []
  58. self.make_data(results, self.run_sql_result_file_excel, self.plot_data_file)
  59. data.append(f'{file_url}' + self.file_path + '/' + self.client_id + '/' + self.run_sql_result_file_excel.split('/')[-1])
  60. data.append(f'{file_url}' + self.file_path + '/' + self.client_id + '/' + self.plot_data_file.split('/')[-1])
  61. print(f'data_send: \n {data}\n')
  62. await ws.send_json({'text': summary_content,'files': []})
  63. msg = {
  64. 'text': '',
  65. 'files': data,
  66. }
  67. await ws.send_json(msg)
  68. await ws.send_text('end')
  69. class code_analyze(code_analyze_father):
  70. def __init__(self, client_id: str, history: list, files: list) -> None:
  71. super().__init__(client_id, history, files)
  72. async def ws_send_data(self, ws: WebSocket, prompt):
  73. summary, data_all_send = await self.run_seperate_jupyter_auto(prompt=prompt)
  74. await ws.send_json({'text': summary,'files': []})
  75. await ws.send_json({'text': '','files': data_all_send})
  76. await ws.send_text('end')
  77. @app.websocket("/ws/{client_id}")
  78. async def websocket_endpoint(ws: WebSocket, client_id: str):
  79. await ws.accept()
  80. try:
  81. while True:
  82. continue_exe = False
  83. message_origin = await ws.receive()
  84. if message_origin.get('type') == 'websocket.disconnect':
  85. break
  86. if message_origin.get('text'):
  87. message_use = json.loads(message_origin.get('text'))
  88. if message_origin.get('bytes'):
  89. audio_data = message_origin.get('bytes')
  90. res = stt_model.generate(
  91. input=audio_data,
  92. cache={},
  93. language="auto",
  94. use_itn=True,
  95. batch_size_s=60,
  96. merge_vad=True,
  97. merge_length_s=15,
  98. )
  99. prompt = rich_transcription_postprocess(res[0]["text"])
  100. else:
  101. prompt = message_use.get('prompt', '')
  102. print(f'用户问题:{prompt}')
  103. history = message_use.get('history',[])
  104. print(f'之前对话历史: {history}')
  105. for item in history:
  106. if 'files' in item:
  107. del item['files']
  108. file_names = message_use.get('file_names',[])
  109. db_param_temp = message_use.get('db_param', '')
  110. if db_param_temp:
  111. db_param = db_en_map.get(db_param_temp).get('name')
  112. table_map = db_en_map.get(db_param_temp).get('table_cn_map')
  113. table_name_temp = message_use.get('tables', [])
  114. table_name = [table_map.get(i) for i in table_name_temp]
  115. else:
  116. db_param = message_use.get('db_param', '')
  117. table_name = message_use.get('tables', [])
  118. if len(file_names) > 0 and len(db_param) > 0:
  119. status_map.update({client_id:'normal'})
  120. elif len(file_names)>0:
  121. file_names = [f'./upload/{i}' for i in file_names]
  122. status_map.update({client_id:'code'})
  123. elif len(db_param)>0:
  124. status_map.update({client_id:'sql'})
  125. elif not status_map.get(client_id, None):
  126. status_map.update({client_id:'normal'})
  127. changeable_prompt = prompt + (f' tables_name:{table_name}' if table_name else '') + (f' files:{file_names}' if file_names else '')
  128. use_tools = await validate_use_tools(changeable_prompt)
  129. if use_tools:
  130. # if False:
  131. func_file_names = [os.path.join(os.path.join(BASE_UPLOAD_DIRECTORY, client_id), 'upload/' + s) for s in message_use.get('file_names',[])]
  132. new_prompt = f"{prompt} file_path:{func_file_names}"
  133. function_answer, data_result = generate_result(prompt=new_prompt)
  134. print(data_result)
  135. if data_result:
  136. data_all_result = [f'{file_url}{x}' for x in data_result]
  137. data_big_result = [f'{file_url}{x}' for x in data_result if os.path.getsize(x.replace(STATIC_DIR, BASE_UPLOAD_DIRECTORY)) > 15 * 1024*1024]
  138. final_result = list(set(data_all_result) - set(data_big_result))
  139. print(f'function data_send: \n {final_result}\n')
  140. final_answer = function_answer + '\n 下载链接: \n' + ','.join(data_big_result) if data_big_result else function_answer
  141. await ws.send_json({'text':final_answer,'files': final_result})
  142. await ws.send_text('end')
  143. else:
  144. # await ws.send_json({'text':function_result,'files': []})
  145. # await ws.send_text('end')
  146. continue_exe = True
  147. else:
  148. continue_exe = True
  149. ##大单数据分析
  150. if prompt == '生成零售加盟大单报表':
  151. print(f'文件列表:{file_names}')
  152. excel_file = file_names[0]
  153. if excel_file:
  154. export_file = datetime.now().strftime('%Y%m%d%H%M%S') + '.xlsx'
  155. temp_directory = os.path.join(BASE_UPLOAD_DIRECTORY, client_id)
  156. user_directory = os.path.join(temp_directory, 'upload')
  157. file_location = os.path.join(user_directory, export_file)
  158. print(f'生成零售加盟大单报表文件:{file_location}')
  159. data_processor(excel_file, file_location)
  160. await ws.send_json({'text': '测试成功', 'files': [f'{file_url}{file_location}']})
  161. await ws.send_text('end')
  162. continue_exe = False
  163. else:
  164. await ws.send_json({'text': '请先上传excel表格', 'files': ''})
  165. await ws.send_text('end')
  166. continue_exe = False
  167. if continue_exe:
  168. print(f'继续执行: {continue_exe}')
  169. analyze_detect = await detect_analyze_agent.a_generate_reply(messages=[{'role':'user', 'content':prompt}])
  170. user_content = analyze_detect.get('content') if not isinstance(analyze_detect, str) else analyze_detect
  171. print(f'数据分析意图识别:{user_content} \n')
  172. if '否' in user_content:
  173. status_map.update({client_id:'normal'})
  174. if status_map.get(client_id) == 'code':
  175. try:
  176. print('code model')
  177. if client_id in code_instance_map:
  178. # 如果实例已存在,更新属性
  179. code_instance_map[client_id].files = file_names
  180. code_instance_map[client_id].history = history
  181. await code_instance_map[client_id].ws_send_data(ws=ws, prompt=prompt)
  182. else:
  183. # 如果实例不存在,创建新实例
  184. code_instance_map[client_id] = code_analyze(client_id=client_id, history=history, files=file_names)
  185. await code_instance_map[client_id].ws_send_data(ws=ws, prompt=prompt)
  186. except Exception as e:
  187. await ws.send_json({'text':f'出错啦,遇到了一些问题:{e}','files': []})
  188. await ws.send_text('end')
  189. elif status_map.get(client_id) == 'sql':
  190. # try:
  191. print('sql model')
  192. db_param = get_db_param(db_param)
  193. if client_id in sql_instance_map:
  194. if db_param:
  195. sql_instance_map[client_id].db_param = db_param
  196. if table_name:
  197. sql_instance_map[client_id].table_name = table_name
  198. await sql_instance_map[client_id].ws_send_data(ws=ws, raw_prompt=prompt, file_url=file_url)
  199. else:
  200. # 如果实例不存在,创建新实例
  201. sql_instance_map[client_id] = sql_analyze(data_engineer=data_engineer, client_id=client_id, db_param=db_param, table_name=table_name, file_path=STATIC_DIR)
  202. await sql_instance_map[client_id].ws_send_data(ws=ws, raw_prompt=prompt, file_url=file_url)
  203. # except Exception as e:
  204. # await ws.send_json({'text':f'出错啦,遇到了一些问题:{e}','files': []})
  205. # await ws.send_text('end')
  206. elif status_map.get(client_id) == 'normal':
  207. print('normal model')
  208. normal_answer = ConversableAgent(
  209. 'normal',
  210. system_message="You are a helpful assistant to answer the user's question as best as you can.",
  211. llm_config=llm_config)
  212. messages = deepcopy(history)
  213. messages.append({'role':'user', 'content': prompt})
  214. answer = await normal_answer.a_generate_reply(messages=messages)
  215. answer = answer.get('content','很抱歉我可能无法回答你的问题') if isinstance(answer, dict) else answer
  216. print(f'final_answer: \n {answer}\n')
  217. await ws.send_json({'text':answer,'files': []})
  218. await ws.send_text('end')
  219. # except Exception as e:
  220. # print(f'there is some error:{e}')
  221. # await ws.send_json({'text':f'报错啦, 请重新问一下问题或者联系管理员','files': []})
  222. except WebSocketDisconnect:
  223. # 当连接断开时,移除对应的实例
  224. print('websocket正常断开')
  225. for i, map in enumerate([code_instance_map, sql_instance_map, status_map]):
  226. if client_id in map:
  227. if i == 0:
  228. map[client_id].executor.stop()
  229. del map[client_id]
  230. print(f"Client {client_id} disconnected, {map} removed")
  231. finally:
  232. # 当连接断开时,移除对应的实例
  233. for i, map in enumerate([code_instance_map, sql_instance_map, status_map]):
  234. if client_id in map:
  235. if i == 0:
  236. container = docker_client.containers.get(map[client_id].jupyter_server._container_id)
  237. container.stop()
  238. container.remove()
  239. # map[client_id].jupyter_server.stop()
  240. # print(f'{client_id} stop the jupyter_server')
  241. # map[client_id].executor.stop()
  242. # print(f'{client_id} stop the executer')
  243. del map[client_id]
  244. print(f"Client {client_id} disconnected")
  245. print(f'The map after removed: \ncode map:{code_instance_map} \nsql map: {sql_instance_map}, \nstatus map: {status_map}')
  246. @app.post("/uploadfile/")
  247. async def create_upload_file(file: UploadFile = File(...), client_id: str = Form(...)):
  248. temp_directory = os.path.join(BASE_UPLOAD_DIRECTORY, client_id)
  249. user_directory = os.path.join(temp_directory, 'upload')
  250. if not os.path.exists(user_directory):
  251. os.makedirs(user_directory)
  252. os.chmod(user_directory, 0o777) # 设置用户目录权限为777
  253. file_location = os.path.join(user_directory, file.filename)
  254. try:
  255. with open(file_location, "wb+") as file_object:
  256. file_object.write(file.file.read())
  257. os.chmod(file_location, 0o777) # 设置文件权限为777
  258. return JSONResponse(content={
  259. "message": f"文件 '{file.filename}' 上传成功",
  260. "client_id": client_id,
  261. "file_path": file_location
  262. }, status_code=200)
  263. except Exception as e:
  264. return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
  265. @app.post("/database")
  266. async def get_table_name():
  267. database = []
  268. try:
  269. for db in db_list_2:
  270. all_db = {}
  271. conn = pymssql.connect(**db)
  272. cursor = conn.cursor()
  273. cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE'")
  274. table_names = cursor.fetchall()
  275. table_list = [s[0] for s in table_names]
  276. all_db['db_param'] = db_cn_map.get(db['database']).get('name')
  277. table_map = db_cn_map.get(db['database']).get('table_cn_map')
  278. table_list = [table_map.get(i) for i in table_list]
  279. # all_db['db_param'] = db['database']
  280. all_db['tables'] = table_list
  281. database.append(all_db)
  282. return JSONResponse(content={"databaseInfo": database}, status_code=200)
  283. except Exception as e:
  284. return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
  285. @app.post("/get_client_id")
  286. async def get_table_name():
  287. # 生成一个基于UUID4的随机数
  288. try:
  289. random_uuid = str(uuid.uuid4())
  290. return JSONResponse(content={"client_id": random_uuid}, status_code=200)
  291. except Exception as e:
  292. return JSONResponse(content={"message": f"发生错误: {str(e)}"}, status_code=500)
  293. if __name__ == "__main__":
  294. import uvicorn
  295. uvicorn.run(app, host="0.0.0.0", port=port)