""" 线稿图生成模块 功能: 1. 从服装款式图生成平铺图 2. 使用 ControlNet linear 模型从平铺图提取线稿 3. 自动质量检查 4. 批量处理图片 """ import os import json import glob import io import tempfile import requests from datetime import datetime from typing import List, Optional from PIL import Image from .check import process_image_pair_with_gemini from .conf import check_prompt from .upload_tos import process_cropped_upload, upload_image from .qwen_edit import qwen_edit from .llm import llm_request from .conf import ali_ky from .prompt import flat_layout_prompt_v2 from .logger_setup import logger # ==================== 常量定义 ==================== DEFAULT_SKETCH_DIR = r"D:\线稿图\线稿图" DEFAULT_LOG_FILE = "sketch_log.json" # 默认线稿图生成提示词 DEFAULT_SKETCH_PROMPT = ( "生成图片里衣服(如果有内衬则包括内衬)的服装平面款式图," "要平铺效果的线稿,仅仅保留外部轮廓和关键结构," "无多余拼接线,去除颜色,保持原比例,去除褶皱," "不要添加或删减元素,保持衣服的细节,只保留衣服的线稿," "去除其他非衣服之外的元素" ) IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.JPG', '.JPEG', '.PNG', '.BMP', '.GIF'] # 最大重试次数 MAX_RETRY_COUNT = 2 # ==================== 平铺图生成函数 ==================== def generate_flat_layout_from_url( image_url: str, prompt: Optional[str] = None ) -> Optional[str]: """ 从图片URL生成平铺图 Args: image_url: 款式图片URL prompt: 平铺图生成提示词,如果为None则自动生成 Returns: 生成的平铺图URL,失败返回None """ try: # 下载图片到临时文件 response = requests.get(image_url, timeout=30) response.raise_for_status() # 创建临时文件 with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file: tmp_file.write(response.content) tmp_path = tmp_file.name try: # 如果没有提供提示词,使用LLM生成 if prompt is None: logger.info("自动生成平铺图提示词...") llm = llm_request(api_key=ali_ky[0], base_url=ali_ky[1], model="qwen3-vl-plus") prompt = llm.llm_mm_request( usr_text="帮我生成这条衣服的平铺图指令", img=tmp_path, sys_text=flat_layout_prompt_v2 ) logger.info(f"生成的提示词: {prompt}") # 创建临时文件保存平铺图 with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as flat_tmp: flat_layout_path = flat_tmp.name # 使用qwen_edit生成平铺图 logger.info("开始生成平铺图...") qwen_edit(tmp_path, prompt, flat_layout_path) # 检查文件是否生成成功 if not os.path.exists(flat_layout_path) or os.path.getsize(flat_layout_path) == 0: logger.error("平铺图生成失败") return None # 上传平铺图到TOS flat_layout_url = upload_image(flat_layout_path) # 清理临时文件 try: os.unlink(tmp_path) os.unlink(flat_layout_path) except: pass if flat_layout_url: logger.info(f"✅ 平铺图生成成功: {flat_layout_url}") return flat_layout_url else: logger.error("上传平铺图失败") return None except Exception as e: logger.error(f"生成平铺图时出错: {e}") # 清理临时文件 try: os.unlink(tmp_path) except: pass return None except Exception as e: logger.error(f"下载图片失败: {e}") return None # ==================== ControlNet Linear 线稿提取函数 ==================== def extract_lineart_with_controlnet( image_url: str, controlnet_model: str = "lineart", api_key: Optional[str] = None ) -> Optional[Image.Image]: """ 使用 ControlNet linear 模型从图片中提取线稿 Args: image_url: 输入图片URL(平铺图) controlnet_model: ControlNet模型类型,默认为 "lineart" api_key: API密钥(可选,从环境变量获取) Returns: PIL Image对象(线稿图),失败返回None """ try: # 尝试使用 FAL API(如果可用) fal_key = api_key or os.environ.get("FAL_KEY") if fal_key: return _extract_lineart_with_fal(image_url, fal_key, controlnet_model) else: # 如果没有 FAL API,使用其他 ControlNet 服务 # 这里可以添加其他 ControlNet API 调用 logger.warning("FAL_KEY 未设置,尝试使用本地 ControlNet 处理") return _extract_lineart_local(image_url) except Exception as e: logger.error(f"提取线稿时出错: {e}") return None def _extract_lineart_with_fal( image_url: str, api_key: str, controlnet_model: str = "lineart" ) -> Optional[Image.Image]: """ 使用 FAL API 的 ControlNet 服务提取线稿 Args: image_url: 输入图片URL api_key: FAL API密钥 controlnet_model: ControlNet模型类型 Returns: PIL Image对象,失败返回None """ try: import fal_client logger.info(f"使用 FAL API 提取线稿: {image_url}") # 调用 FAL API 的 ControlNet lineart 模型 result = fal_client.subscribe( "fal-ai/controlnet-lineart", arguments={ "image_url": image_url, "model": controlnet_model }, api_key=api_key ) # 获取结果图片URL if result and "images" in result: output_url = result["images"][0].get("url") if isinstance(result["images"], list) else result["images"].get("url") if output_url: # 下载生成的线稿图 response = requests.get(output_url, timeout=30) response.raise_for_status() img = Image.open(io.BytesIO(response.content)).convert("RGB") logger.info("✅ 线稿提取成功") return img logger.error("FAL API 返回结果格式错误") return None except ImportError: logger.warning("fal_client 未安装,尝试使用 HTTP 请求") return _extract_lineart_with_http(image_url, api_key, controlnet_model) except Exception as e: logger.error(f"FAL API 调用失败: {e}") return None def _extract_lineart_with_http( image_url: str, api_key: str, controlnet_model: str = "lineart" ) -> Optional[Image.Image]: """ 使用 HTTP 请求调用 FAL API 提取线稿 Args: image_url: 输入图片URL api_key: FAL API密钥 controlnet_model: ControlNet模型类型 Returns: PIL Image对象,失败返回None """ try: url = "https://fal.run/fal-ai/controlnet-lineart" headers = { "Authorization": f"Key {api_key}", "Content-Type": "application/json" } payload = { "image_url": image_url, "model": controlnet_model } response = requests.post(url, json=payload, headers=headers, timeout=60) response.raise_for_status() result = response.json() # 获取结果图片URL if result and "images" in result: output_url = result["images"][0].get("url") if isinstance(result["images"], list) else result["images"].get("url") if output_url: # 下载生成的线稿图 img_response = requests.get(output_url, timeout=30) img_response.raise_for_status() img = Image.open(io.BytesIO(img_response.content)).convert("RGB") logger.info("✅ 线稿提取成功") return img logger.error("HTTP API 返回结果格式错误") return None except Exception as e: logger.error(f"HTTP API 调用失败: {e}") return None def _extract_lineart_local(image_url: str) -> Optional[Image.Image]: """ 使用本地方法提取线稿(简单的边缘检测) 注意:这是一个备用方案,效果不如 ControlNet 如果可能,建议使用 FAL API 或其他 ControlNet 服务 Args: image_url: 输入图片URL Returns: PIL Image对象,失败返回None """ try: import cv2 import numpy as np logger.info("使用本地方法提取线稿(边缘检测)...") # 下载图片 response = requests.get(image_url, timeout=30) response.raise_for_status() # 将字节数据转换为numpy数组 img_array = np.asarray(bytearray(response.content), dtype=np.uint8) img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) # 转换为灰度图 gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # 使用 Canny 边缘检测 edges = cv2.Canny(gray, 50, 150) # 反转颜色(黑底白线 -> 白底黑线) edges = 255 - edges # 转换回 PIL Image pil_img = Image.fromarray(edges).convert("RGB") logger.info("✅ 本地线稿提取完成(使用边缘检测)") return pil_img except ImportError: logger.error("OpenCV 未安装,无法使用本地方法") return None except Exception as e: logger.error(f"本地线稿提取失败: {e}") return None # ==================== 核心功能函数 ==================== def generate_sketch( image_url: str, prompt: Optional[str] = None, save_dir: Optional[str] = None, max_retries: int = MAX_RETRY_COUNT, auto_check: bool = True, flat_layout_prompt: Optional[str] = None, controlnet_model: str = "lineart" ) -> Optional[str]: """ 生成服装线稿图(新流程:先生成平铺图,再提取线稿) 流程: 1. 从款式图生成平铺图(使用 qwen_edit) 2. 使用 ControlNet linear 模型从平铺图提取线稿 3. 质量检查(可选) Args: image_url: 款式图片URL prompt: 线稿图提示词(已废弃,保留用于兼容) save_dir: 保存目录,如果为None则不保存到本地 max_retries: 最大重试次数 auto_check: 是否自动质量检查 flat_layout_prompt: 平铺图生成提示词,如果为None则自动生成 controlnet_model: ControlNet模型类型,默认为 "lineart" Returns: 生成的线稿图URL,失败返回None """ if not image_url: logger.error("图片URL不能为空") return None logger.info(f"开始生成线稿图(新流程): {image_url}") logger.info("流程:款式图 → 平铺图 → 线稿图") # 尝试生成线稿图,最多重试max_retries次 for attempt in range(max_retries): logger.info(f"第 {attempt + 1}/{max_retries} 次尝试生成线稿图") try: # 第一步:生成平铺图 logger.info("=" * 50) logger.info("步骤 1/2: 生成平铺图") logger.info("=" * 50) flat_layout_url = generate_flat_layout_from_url( image_url=image_url, prompt=flat_layout_prompt ) if flat_layout_url is None: logger.warning(f"第 {attempt + 1} 次生成平铺图失败,继续重试") continue logger.info(f"✅ 平铺图生成成功: {flat_layout_url}") # 第二步:使用 ControlNet linear 提取线稿 logger.info("=" * 50) logger.info("步骤 2/2: 使用 ControlNet linear 提取线稿") logger.info("=" * 50) sketch_image = extract_lineart_with_controlnet( image_url=flat_layout_url, controlnet_model=controlnet_model ) if sketch_image is None: logger.warning(f"第 {attempt + 1} 次提取线稿失败,继续重试") continue # 上传线稿图获取URL sketch_url = process_cropped_upload(sketch_image) if sketch_url is None: logger.warning(f"第 {attempt + 1} 次上传线稿图失败,继续重试") continue logger.info(f"✅ 线稿图提取成功: {sketch_url}") # 如果启用自动检查,进行质量验证 if auto_check: logger.info("=" * 50) logger.info("进行质量检查...") logger.info("=" * 50) check_result = process_image_pair_with_gemini( image1_url=image_url, image2_url=sketch_url, prompt=check_prompt ) if check_result: check_result = check_result.strip() # 检查是否通过(回答"是") if check_result and ("是" in check_result or check_result[0] == "是"): logger.info(f"✅ 质量检查通过: {check_result}") logger.info(f"✅ 线稿图生成成功: {sketch_url}") return sketch_url else: logger.warning(f"⚠️ 质量检查未通过: {check_result}") if attempt < max_retries - 1: logger.info("继续重试...") continue else: logger.warning("质量检查失败,但继续使用生成的图片") # 如果没有启用检查或检查失败但达到最大重试次数,返回结果 logger.info(f"✅ 线稿图生成成功: {sketch_url}") return sketch_url except Exception as e: logger.error(f"第 {attempt + 1} 次尝试时出错: {e}") if attempt < max_retries - 1: continue else: logger.error(f"❌ 生成线稿图失败: {e}") return None logger.error("❌ 达到最大重试次数,生成线稿图失败") return None def generate_sketch_from_local( image_path: str, prompt: Optional[str] = None, save_dir: Optional[str] = None, max_retries: int = MAX_RETRY_COUNT, auto_check: bool = True, model: str = "gemini-2.5-flash-image", resolution: str = "1K" ) -> Optional[str]: """ 从本地图片文件生成线稿图 首先上传本地图片到TOS获取URL,然后调用generate_sketch Args: image_path: 本地图片路径 prompt: 提示词 save_dir: 保存目录 max_retries: 最大重试次数 auto_check: 是否自动质量检查 model: 使用的模型 resolution: 分辨率 Returns: 生成的线稿图URL,失败返回None """ if not os.path.exists(image_path): logger.error(f"图片文件不存在: {image_path}") return None try: # 上传本地图片获取URL from .upload_tos import upload_image image_url = upload_image(image_path) if image_url is None: logger.error("上传图片失败") return None logger.info(f"图片已上传: {image_url}") # 调用生成函数 return generate_sketch( image_url=image_url, prompt=prompt, save_dir=save_dir, max_retries=max_retries, auto_check=auto_check, model=model, resolution=resolution ) except Exception as e: logger.error(f"处理本地图片时出错: {e}") return None # ==================== 工具函数 ==================== def get_image_files(directory: str, extensions: Optional[List[str]] = None) -> List[str]: """ 获取目录下所有图片文件 Args: directory: 目录路径 extensions: 图片扩展名列表 Returns: 图片文件路径列表(已排序) """ if extensions is None: extensions = IMAGE_EXTENSIONS image_files = [] if not os.path.exists(directory): return image_files for ext in extensions: pattern = os.path.join(directory, f'*{ext}') image_files.extend(glob.glob(pattern)) return sorted(image_files) def save_sketch_log(image_url: str, sketch_url: str, prompt: str, log_file: str = DEFAULT_LOG_FILE) -> None: """ 保存线稿图生成记录到JSON文件 Args: image_url: 原始图片URL sketch_url: 生成的线稿图URL prompt: 使用的提示词 log_file: 日志文件路径 """ log_data = { "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "image_url": image_url, "sketch_url": sketch_url, "prompt": prompt } # 读取现有数据 if os.path.exists(log_file): try: with open(log_file, 'r', encoding='utf-8') as f: data = json.load(f) except (json.JSONDecodeError, FileNotFoundError): data = [] else: data = [] # 添加新记录 data.append(log_data) # 保存到文件 with open(log_file, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) logger.info(f"✅ 已保存记录到 {log_file}") # ==================== 批量处理功能 ==================== def batch_generate_sketch_from_urls( image_urls: List[str], prompt: Optional[str] = None, max_retries: int = MAX_RETRY_COUNT, auto_check: bool = True, log_file: Optional[str] = None ) -> dict: """ 批量从图片URL列表生成线稿图 Args: image_urls: 图片URL列表 prompt: 提示词 max_retries: 最大重试次数 auto_check: 是否自动质量检查 log_file: 日志文件路径,如果为None则不保存 Returns: 处理结果字典,包含成功和失败的数量 """ logger.info(f"开始批量处理 {len(image_urls)} 个图片") success_count = 0 fail_count = 0 all_results = [] for idx, image_url in enumerate(image_urls, 1): logger.info(f"\n{'='*50}") logger.info(f"处理第 {idx}/{len(image_urls)} 个图片") logger.info(f"{'='*50}") try: sketch_url = generate_sketch( image_url=image_url, prompt=prompt, max_retries=max_retries, auto_check=auto_check ) if sketch_url: result = { "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "image_url": image_url, "sketch_url": sketch_url, "prompt": prompt or DEFAULT_SKETCH_PROMPT, "success": True } all_results.append(result) success_count += 1 # 保存日志 if log_file: save_sketch_log(image_url, sketch_url, prompt or DEFAULT_SKETCH_PROMPT, log_file) else: result = { "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "image_url": image_url, "success": False, "error": "生成失败" } all_results.append(result) fail_count += 1 except Exception as e: logger.error(f"处理图片时出错: {e}") result = { "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "image_url": image_url, "success": False, "error": str(e) } all_results.append(result) fail_count += 1 continue # 保存结果到文件 if log_file: try: with open(log_file, 'w', encoding='utf-8') as f: json.dump(all_results, f, ensure_ascii=False, indent=2) logger.info(f"✅ 已保存处理结果到: {log_file}") except Exception as e: logger.error(f"保存文件时出错: {e}") logger.info(f"\n{'='*50}") logger.info(f"✅ 批量处理完成!") logger.info(f" 成功: {success_count} 个") logger.info(f" 失败: {fail_count} 个") logger.info(f"{'='*50}") return { "success_count": success_count, "fail_count": fail_count, "total": len(image_urls), "results": all_results } def batch_generate_sketch_from_directory( directory: str, prompt: Optional[str] = None, max_retries: int = MAX_RETRY_COUNT, auto_check: bool = True, log_file: Optional[str] = None ) -> dict: """ 从目录中扫描图片,批量生成线稿图 Args: directory: 图片目录路径 prompt: 提示词 max_retries: 最大重试次数 auto_check: 是否自动质量检查 log_file: 日志文件路径 Returns: 处理结果字典 """ # 获取所有图片文件 image_files = get_image_files(directory) if not image_files: logger.warning(f"在目录 {directory} 中未找到图片文件") return { "success_count": 0, "fail_count": 0, "total": 0, "results": [] } logger.info(f"找到 {len(image_files)} 个图片文件") # 上传所有图片获取URL from .upload_tos import upload_image image_urls = [] for image_path in image_files: try: image_url = upload_image(image_path) if image_url: image_urls.append(image_url) except Exception as e: logger.error(f"上传图片失败 {image_path}: {e}") continue logger.info(f"成功上传 {len(image_urls)} 个图片") # 批量生成线稿图 return batch_generate_sketch_from_urls( image_urls=image_urls, prompt=prompt, max_retries=max_retries, auto_check=auto_check, log_file=log_file ) # ==================== 主程序入口 ==================== if __name__ == "__main__": # 示例:从单个图片URL生成线稿图 test_image_url = "https://example.com/garment.jpg" result = generate_sketch(test_image_url) if result: print(f"✅ 线稿图生成成功: {result}") else: print("❌ 线稿图生成失败")