depth_img.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import os
  2. from transformers import pipeline
  3. from PIL import Image
  4. from tqdm import tqdm
  5. class DepthEstimator:
  6. def __init__(self, model_path="models/Depth-Anything-V2-Small-hf"):
  7. """
  8. 初始化深度估计器,加载模型
  9. Args:
  10. model_path (str): 深度估计模型的路径
  11. """
  12. self.pipe = pipeline(task="depth-estimation", model=model_path)
  13. def estimate_depth(self, image_path):
  14. """
  15. 估计图像的深度并保存深度图
  16. Args:
  17. image_path (str): 输入图像的路径
  18. Returns:
  19. PIL.Image.Image: 生成的深度图
  20. """
  21. # 打开输入图像
  22. image = Image.open(image_path)
  23. # 进行深度估计
  24. depth = self.pipe(image)["depth"]
  25. # 保存深度图
  26. output_depth_path = f"data/depth_image/depth-{os.path.splitext(os.path.basename(image_path))[0]}.png"
  27. depth.save(output_depth_path)
  28. return depth
  29. def depth_image_list(image_paths):
  30. """
  31. 处理图像列表,返回深度图像列表
  32. Args:
  33. image_paths (list): 输入图像路径列表
  34. Returns:
  35. list: 深度图像列表
  36. """
  37. depth_estimator = DepthEstimator() # 创建深度估计器实例
  38. depth_images = []
  39. for image_path in tqdm(image_paths):
  40. depth_image = depth_estimator.estimate_depth(image_path) # 估计深度
  41. depth_images.append(depth_image) # 将深度图像添加到列表中
  42. return depth_images
  43. # 使用示例
  44. if __name__ == "__main__":
  45. #model_path = "/data/data/luosy/project/oral/models/Depth-Anything-V2-Small-hf" # 模型路径
  46. input_image_path = 'masked_output_image.png' # 输入图像路径
  47. depth_estimator = DepthEstimator() # 创建深度估计器实例
  48. depth_image = depth_estimator.estimate_depth(input_image_path) # 估计深度
  49. print(type(depth_image))