network_tools.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import time
  2. import json
  3. import urllib.request
  4. import urllib.parse
  5. from backend.utils.logger_config import setup_logger
  6. from backend.utils.system_config import Config
  7. logger = setup_logger(__name__)
  8. def queue_prompt(textPrompt: str, config: Config) -> dict:
  9. """
  10. 向服务器队列发送提示词
  11. Args:
  12. textPrompt: 提示词
  13. config: 配置
  14. Returns:
  15. dict: 提示词ID
  16. """
  17. p = {"prompt": textPrompt, "client_id": config.client_id}
  18. data = json.dumps(p).encode('utf-8')
  19. req = urllib.request.Request(f"http://{config.server_address}/prompt", data=data)
  20. logger.info(f"Queue prompt: {textPrompt}")
  21. try:
  22. response = urllib.request.urlopen(req)
  23. return json.loads(response.read())
  24. except Exception as e:
  25. logger.error(f"Failed to queue prompt: {e}")
  26. raise
  27. def get_image(fileName: str, subFolder: str, folder_type: str, config: Config) -> bytes:
  28. data = {"filename": fileName, "subfolder": subFolder, "type": folder_type}
  29. url_values = urllib.parse.urlencode(data)
  30. try:
  31. with urllib.request.urlopen(f"http://{config.server_address}/view?{url_values}") as response:
  32. return response.read()
  33. except Exception as e:
  34. logger.error(f"Failed to get image: {e}")
  35. raise
  36. def get_history(prompt_id: str, config: Config) -> dict:
  37. try:
  38. with urllib.request.urlopen(f"http://{config.server_address}/history/{prompt_id}") as response:
  39. return json.loads(response.read())
  40. except Exception as e:
  41. logger.error(f"Failed to get history: {e}")
  42. raise
  43. def process_images(images: list, config) -> list:
  44. images_output = []
  45. for image in images:
  46. image_data = get_image(image['filename'], image['subfolder'], image['type'], config)
  47. images_output.append(image_data)
  48. return images_output
  49. def process_videos(videos: list, config) -> list:
  50. videos_output = []
  51. for video in videos:
  52. video_data = get_image(video['filename'], video['subfolder'], video['type'], config)
  53. videos_output.append(video_data)
  54. return videos_output
  55. def get_images(prompt, config: Config, timeout=160):
  56. prompt_id = queue_prompt(prompt, config)['prompt_id']
  57. output_images = {}
  58. history_prompt = ""
  59. completed = False
  60. start_time = time.time()
  61. logger.info(f"Start time: {start_time}")
  62. while not completed:
  63. try:
  64. history = get_history(prompt_id, config)[prompt_id]
  65. if history is not None:
  66. for node_id, node_output in history['outputs'].items():
  67. if 'images' in node_output:
  68. output_images[node_id] = process_images(node_output['images'], config)
  69. completed = True
  70. if 'gifs' in node_output:
  71. output_images[node_id] = process_videos(node_output['gifs'], config)
  72. completed = True
  73. if node_id == '246':
  74. history_prompt = node_output['string'][0]
  75. completed = True
  76. completed = True
  77. if time.time() - start_time > timeout:
  78. logger.warning(f'Fallback to history by timeout')
  79. completed = True
  80. logger.info(f"Time cost: {time.time() - start_time}")
  81. except Exception as e:
  82. pass
  83. logger.info('Image acquisition completed')
  84. return output_images, history_prompt