image_output.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import base64
  2. import cv2
  3. from typing import List, Literal, Optional, Union
  4. from PIL import Image
  5. from utils.tools import download_image
  6. class ImageOutput:
  7. fmt: Literal["b64", "url", "pil", "np"]
  8. ext: str = "png"
  9. data: Union[str, Image.Image]
  10. def __init__(
  11. self,
  12. fmt: Literal["b64", "url", "pil", "np"],
  13. ext: str,
  14. data: Union[str, Image.Image],
  15. ):
  16. self.fmt = fmt
  17. self.ext = ext
  18. self.data = data
  19. def save_b64(self, path: str) -> None:
  20. """Save a base64 encoded image to the specified path.
  21. Args:
  22. path (str): Path where the image will be saved.
  23. """
  24. with open(path, 'wb') as f:
  25. f.write(base64.b64decode(self.data))
  26. def save_url(self, path: str) -> None:
  27. """Download and save an image from a URL to the specified path.
  28. Args:
  29. path (str): Path where the image will be saved.
  30. """
  31. download_image(self.data, path)
  32. def save_pil(self, path: str) -> None:
  33. """Save a PIL Image to the specified path.
  34. Args:
  35. path (str): Path where the image will be saved.
  36. """
  37. self.data.save(path)
  38. def save_np(self, path: str) -> None:
  39. """Save a numpy array to the specified path.
  40. Args:
  41. path (str): Path where the image will be saved.
  42. """
  43. cv2.imencode('.png', self.data)[1].tofile(path)
  44. def save(self, path: str) -> None:
  45. save_func = getattr(self, f"save_{self.fmt}")
  46. save_func(path)
  47. def save_img(self, path: str) -> None:
  48. with open(path, "wb") as f:
  49. f.write(self.data)