| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738 |
- """
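- Line-art (sketch) generation module
- Features:
- 1. Generate a flat-lay image from a garment style image
- 2. Extract line art from the flat-lay image with a ControlNet lineart model
- 3. Run an automatic quality check
- 4. Process images in batches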
- """
- import os
- import json
- import glob
- import io
- import tempfile
- import requests
- from datetime import datetime
- from typing import List, Optional
- from PIL import Image
- from .check import process_image_pair_with_gemini
- from .conf import check_prompt
- from .upload_tos import process_cropped_upload, upload_image
- from .qwen_edit import qwen_edit
- from .llm import llm_request
- from .conf import ali_ky
- from .prompt import flat_layout_prompt_v2
- from .logger_setup import logger
- # ==================== 常量定义 ====================
# Default sketch directory (a Windows path) and default JSON log file name.
DEFAULT_SKETCH_DIR = r"D:\线稿图\线稿图"
DEFAULT_LOG_FILE = "sketch_log.json"

# Default prompt for line-art generation. This is runtime text, so the
# fragments are kept verbatim and simply joined together.
DEFAULT_SKETCH_PROMPT = "".join((
    "生成图片里衣服(如果有内衬则包括内衬)的服装平面款式图,",
    "要平铺效果的线稿,仅仅保留外部轮廓和关键结构,",
    "无多余拼接线,去除颜色,保持原比例,去除褶皱,",
    "不要添加或删减元素,保持衣服的细节,只保留衣服的线稿,",
    "去除其他非衣服之外的元素",
))

# Recognised image suffixes: all lower-case variants first, then the same
# suffixes in upper case.
IMAGE_EXTENSIONS = [
    convert(suffix)
    for convert in (str.lower, str.upper)
    for suffix in (".jpg", ".jpeg", ".png", ".bmp", ".gif")
]

# Maximum number of generation attempts.
MAX_RETRY_COUNT = 2
- # ==================== 平铺图生成函数 ====================
def generate_flat_layout_from_url(
    image_url: str,
    prompt: Optional[str] = None
) -> Optional[str]:
    """
    Generate a flat-lay image from a garment image URL.

    The source image is downloaded to a temporary file, converted into a
    flat-lay image with ``qwen_edit`` and the result is uploaded with
    ``upload_image``. Every temporary file created along the way is removed
    before returning, on success and on every failure path.

    Args:
        image_url: URL of the garment style image.
        prompt: Flat-lay generation prompt. When None, a prompt is requested
            from the multimodal LLM using the downloaded image.

    Returns:
        URL of the uploaded flat-lay image, or None on any failure
        (failures are logged, never raised).
    """
    temp_paths = []  # every temp file created here; removed in ``finally``
    try:
        # Stage 1: download the source image into a temporary file.
        try:
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
                temp_paths.append(tmp_file.name)
                tmp_file.write(response.content)
        except Exception as e:
            logger.error(f"下载图片失败: {e}")
            return None
        tmp_path = temp_paths[0]

        # Stage 2: build the prompt if needed, generate and upload.
        try:
            if prompt is None:
                logger.info("自动生成平铺图提示词...")
                llm = llm_request(api_key=ali_ky[0], base_url=ali_ky[1], model="qwen3-vl-plus")
                prompt = llm.llm_mm_request(
                    usr_text="帮我生成这条衣服的平铺图指令",
                    img=tmp_path,
                    sys_text=flat_layout_prompt_v2
                )
                logger.info(f"生成的提示词: {prompt}")

            # Reserve an output path; qwen_edit writes the result there.
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as flat_tmp:
                flat_layout_path = flat_tmp.name
                temp_paths.append(flat_layout_path)

            logger.info("开始生成平铺图...")
            qwen_edit(tmp_path, prompt, flat_layout_path)

            # A missing or empty output file means generation failed.
            if not os.path.exists(flat_layout_path) or os.path.getsize(flat_layout_path) == 0:
                logger.error("平铺图生成失败")
                return None

            flat_layout_url = upload_image(flat_layout_path)
        except Exception as e:
            logger.error(f"生成平铺图时出错: {e}")
            return None

        if flat_layout_url:
            logger.info(f"✅ 平铺图生成成功: {flat_layout_url}")
            return flat_layout_url

        logger.error("上传平铺图失败")
        return None
    finally:
        # Best-effort cleanup: each file is handled independently so one
        # failed removal cannot leave the other file behind.
        for path in temp_paths:
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(f"清理临时文件失败 {path}: {e}")
- # ==================== ControlNet Linear 线稿提取函数 ====================
def extract_lineart_with_controlnet(
    image_url: str,
    controlnet_model: str = "lineart",
    api_key: Optional[str] = None
) -> Optional[Image.Image]:
    """
    Extract line art from an image, preferring the FAL ControlNet service.

    The FAL key is taken from ``api_key`` or, failing that, from the
    ``FAL_KEY`` environment variable. Without a key the local edge-detection
    fallback is used instead.

    Args:
        image_url: URL of the input image (the flat-lay image).
        controlnet_model: ControlNet model type, "lineart" by default.
        api_key: Optional FAL API key; overrides the environment variable.

    Returns:
        The line-art image as a PIL Image, or None on failure.
    """
    try:
        resolved_key = api_key if api_key else os.environ.get("FAL_KEY")

        if not resolved_key:
            # No remote credentials available: fall back to local processing.
            logger.warning("FAL_KEY 未设置,尝试使用本地 ControlNet 处理")
            return _extract_lineart_local(image_url)

        return _extract_lineart_with_fal(image_url, resolved_key, controlnet_model)
    except Exception as error:
        logger.error(f"提取线稿时出错: {error}")
        return None
def _extract_lineart_with_fal(
    image_url: str,
    api_key: str,
    controlnet_model: str = "lineart"
) -> Optional[Image.Image]:
    """
    Extract line art through the FAL ControlNet service using ``fal_client``.

    Falls back to the plain HTTP implementation when ``fal_client`` is not
    installed.

    Args:
        image_url: URL of the input image.
        api_key: FAL API key.
        controlnet_model: ControlNet model type.

    Returns:
        The line-art image as an RGB PIL Image, or None on failure.
    """
    try:
        import fal_client
    except ImportError:
        logger.warning("fal_client 未安装,尝试使用 HTTP 请求")
        return _extract_lineart_with_http(image_url, api_key, controlnet_model)

    try:
        logger.info(f"使用 FAL API 提取线稿: {image_url}")

        # The module-level fal_client.subscribe() does not accept an api_key
        # keyword; credentials are supplied by building a client with an
        # explicit key instead.
        client = fal_client.SyncClient(key=api_key)
        result = client.subscribe(
            "fal-ai/controlnet-lineart",
            arguments={
                "image_url": image_url,
                "model": controlnet_model
            }
        )

        # "images" may be a list of dicts or a single dict; an empty list or
        # an unexpected entry type is treated as a malformed response.
        images = result.get("images") if isinstance(result, dict) else None
        first_image = images[0] if isinstance(images, list) and images else images
        output_url = first_image.get("url") if isinstance(first_image, dict) else None

        if output_url:
            # Download the generated line-art image.
            response = requests.get(output_url, timeout=30)
            response.raise_for_status()
            img = Image.open(io.BytesIO(response.content)).convert("RGB")
            logger.info("✅ 线稿提取成功")
            return img

        logger.error("FAL API 返回结果格式错误")
        return None

    except Exception as e:
        logger.error(f"FAL API 调用失败: {e}")
        return None
def _extract_lineart_with_http(
    image_url: str,
    api_key: str,
    controlnet_model: str = "lineart"
) -> Optional[Image.Image]:
    """
    Extract line art by calling the FAL endpoint directly over HTTP.

    Args:
        image_url: URL of the input image.
        api_key: FAL API key, sent as ``Authorization: Key <api_key>``.
        controlnet_model: ControlNet model type.

    Returns:
        The line-art image as an RGB PIL Image, or None on failure.
    """
    try:
        url = "https://fal.run/fal-ai/controlnet-lineart"
        headers = {
            "Authorization": f"Key {api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "image_url": image_url,
            "model": controlnet_model
        }

        response = requests.post(url, json=payload, headers=headers, timeout=60)
        response.raise_for_status()
        result = response.json()

        # "images" may be a list of dicts or a single dict; an empty list or
        # an unexpected entry type is reported as a malformed response rather
        # than surfacing as an IndexError/AttributeError.
        images = result.get("images") if isinstance(result, dict) else None
        first_image = images[0] if isinstance(images, list) and images else images
        output_url = first_image.get("url") if isinstance(first_image, dict) else None

        if output_url:
            # Download the generated line-art image.
            img_response = requests.get(output_url, timeout=30)
            img_response.raise_for_status()
            img = Image.open(io.BytesIO(img_response.content)).convert("RGB")
            logger.info("✅ 线稿提取成功")
            return img

        logger.error("HTTP API 返回结果格式错误")
        return None

    except Exception as e:
        logger.error(f"HTTP API 调用失败: {e}")
        return None
def _extract_lineart_local(image_url: str) -> Optional[Image.Image]:
    """
    Extract line art locally with plain Canny edge detection.

    This is a fallback only; the result is not comparable to a ControlNet
    model. Prefer the FAL API or another ControlNet service when available.

    Args:
        image_url: URL of the input image.

    Returns:
        A white-background, black-line RGB PIL Image, or None on failure
        (including when OpenCV/numpy are not installed).
    """
    try:
        import cv2
        import numpy as np

        logger.info("使用本地方法提取线稿(边缘检测)...")

        # Download the image and decode the raw bytes with OpenCV.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        img_array = np.frombuffer(response.content, dtype=np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)

        # imdecode signals undecodable data by returning None, not by raising.
        if img is None:
            logger.error("本地线稿提取失败: 无法解码图片数据")
            return None

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)

        # Invert: Canny yields white lines on black; we want black on white.
        edges = 255 - edges

        pil_img = Image.fromarray(edges).convert("RGB")

        logger.info("✅ 本地线稿提取完成(使用边缘检测)")
        return pil_img

    except ImportError:
        logger.error("OpenCV 未安装,无法使用本地方法")
        return None
    except Exception as e:
        logger.error(f"本地线稿提取失败: {e}")
        return None
- # ==================== 核心功能函数 ====================
def generate_sketch(
    image_url: str,
    prompt: Optional[str] = None,
    save_dir: Optional[str] = None,
    max_retries: int = MAX_RETRY_COUNT,
    auto_check: bool = True,
    flat_layout_prompt: Optional[str] = None,
    controlnet_model: str = "lineart"
) -> Optional[str]:
    """
    Generate a garment line-art image (flat-lay first, then line art).

    Pipeline per attempt:
        1. Generate a flat-lay image from the style image (qwen_edit).
        2. Extract line art from the flat-lay image (ControlNet lineart).
        3. Upload the line art and, optionally, run a quality check.

    A failed quality check triggers a retry; on the final attempt the
    generated image is returned anyway, with a warning.

    Args:
        image_url: URL of the garment style image.
        prompt: Line-art prompt. Deprecated and unused; kept for
            compatibility.
        save_dir: Local save directory. Currently unused by this function.
        max_retries: Maximum number of attempts.
        auto_check: Whether to run the automatic quality check.
        flat_layout_prompt: Prompt for flat-lay generation; generated
            automatically when None.
        controlnet_model: ControlNet model type, "lineart" by default.

    Returns:
        URL of the generated line-art image, or None on failure.
    """
    if not image_url:
        logger.error("图片URL不能为空")
        return None

    logger.info(f"开始生成线稿图(新流程): {image_url}")
    logger.info("流程:款式图 → 平铺图 → 线稿图")

    for attempt in range(max_retries):
        logger.info(f"第 {attempt + 1}/{max_retries} 次尝试生成线稿图")

        try:
            # Step 1: generate the flat-lay image.
            logger.info("=" * 50)
            logger.info("步骤 1/2: 生成平铺图")
            logger.info("=" * 50)

            flat_layout_url = generate_flat_layout_from_url(
                image_url=image_url,
                prompt=flat_layout_prompt
            )

            if flat_layout_url is None:
                logger.warning(f"第 {attempt + 1} 次生成平铺图失败,继续重试")
                continue

            logger.info(f"✅ 平铺图生成成功: {flat_layout_url}")

            # Step 2: extract line art from the flat-lay image.
            logger.info("=" * 50)
            logger.info("步骤 2/2: 使用 ControlNet linear 提取线稿")
            logger.info("=" * 50)

            sketch_image = extract_lineart_with_controlnet(
                image_url=flat_layout_url,
                controlnet_model=controlnet_model
            )

            if sketch_image is None:
                logger.warning(f"第 {attempt + 1} 次提取线稿失败,继续重试")
                continue

            # Upload the line art to obtain its URL.
            sketch_url = process_cropped_upload(sketch_image)
            if sketch_url is None:
                logger.warning(f"第 {attempt + 1} 次上传线稿图失败,继续重试")
                continue

            logger.info(f"✅ 线稿图提取成功: {sketch_url}")

            if auto_check:
                logger.info("=" * 50)
                logger.info("进行质量检查...")
                logger.info("=" * 50)

                check_result = process_image_pair_with_gemini(
                    image1_url=image_url,
                    image2_url=sketch_url,
                    prompt=check_prompt
                )

                if check_result:
                    check_result = check_result.strip()
                    # The verdict must START with "是" (yes). A substring
                    # test would also accept negative answers such as "不是".
                    # Leading quotes/markdown emphasis are ignored.
                    verdict = check_result.lstrip("*\"'“‘「 ")
                    if verdict.startswith("是"):
                        logger.info(f"✅ 质量检查通过: {check_result}")
                        logger.info(f"✅ 线稿图生成成功: {sketch_url}")
                        return sketch_url

                    logger.warning(f"⚠️ 质量检查未通过: {check_result}")
                    if attempt < max_retries - 1:
                        logger.info("继续重试...")
                        continue
                    logger.warning("质量检查失败,但继续使用生成的图片")

            # Reached when checking is disabled, the checker returned nothing,
            # or the check failed on the final attempt.
            logger.info(f"✅ 线稿图生成成功: {sketch_url}")
            return sketch_url

        except Exception as e:
            logger.error(f"第 {attempt + 1} 次尝试时出错: {e}")
            if attempt < max_retries - 1:
                continue
            logger.error(f"❌ 生成线稿图失败: {e}")
            return None

    logger.error("❌ 达到最大重试次数,生成线稿图失败")
    return None
def generate_sketch_from_local(
    image_path: str,
    prompt: Optional[str] = None,
    save_dir: Optional[str] = None,
    max_retries: int = MAX_RETRY_COUNT,
    auto_check: bool = True,
    model: str = "gemini-2.5-flash-image",
    resolution: str = "1K"
) -> Optional[str]:
    """
    Generate a line-art image from a local image file.

    The file is first uploaded to obtain a URL, then ``generate_sketch`` is
    called with that URL.

    Args:
        image_path: Path of the local image.
        prompt: Prompt, forwarded to ``generate_sketch``.
        save_dir: Save directory, forwarded to ``generate_sketch``.
        max_retries: Maximum number of attempts.
        auto_check: Whether to run the automatic quality check.
        model: Accepted for backward compatibility only; not forwarded,
            because ``generate_sketch`` has no such parameter.
        resolution: Accepted for backward compatibility only; not forwarded,
            for the same reason.

    Returns:
        URL of the generated line-art image, or None on failure.
    """
    if not os.path.exists(image_path):
        logger.error(f"图片文件不存在: {image_path}")
        return None

    try:
        # Upload the local image to obtain a URL.
        image_url = upload_image(image_path)

        if image_url is None:
            logger.error("上传图片失败")
            return None

        logger.info(f"图片已上传: {image_url}")

        # Forward only keywords generate_sketch actually accepts; passing
        # model/resolution raised TypeError and made every call return None.
        return generate_sketch(
            image_url=image_url,
            prompt=prompt,
            save_dir=save_dir,
            max_retries=max_retries,
            auto_check=auto_check
        )

    except Exception as e:
        logger.error(f"处理本地图片时出错: {e}")
        return None
- # ==================== 工具函数 ====================
def get_image_files(directory: str, extensions: Optional[List[str]] = None) -> List[str]:
    """
    Return the image files found directly inside a directory.

    Args:
        directory: Directory to scan (not recursive).
        extensions: File suffixes to match, e.g. ``['.jpg', '.png']``.
            Defaults to ``IMAGE_EXTENSIONS``.

    Returns:
        Sorted list of matching paths without duplicates; an empty list when
        the directory does not exist.
    """
    if extensions is None:
        extensions = IMAGE_EXTENSIONS

    if not os.path.exists(directory):
        return []

    # Escape the directory so characters like '[' in its name are taken
    # literally instead of being interpreted as glob syntax.
    escaped_dir = glob.escape(directory)

    # A set removes duplicates that occur when several patterns match the
    # same file (e.g. '.jpg' and '.JPG' on a case-insensitive filesystem).
    matches = set()
    for ext in extensions:
        matches.update(glob.glob(os.path.join(escaped_dir, f'*{ext}')))

    return sorted(matches)
def save_sketch_log(image_url: str, sketch_url: str, prompt: str,
                    log_file: str = DEFAULT_LOG_FILE) -> None:
    """
    Append one line-art generation record to a JSON log file.

    The file holds a JSON list of records. A missing or unparsable file is
    treated as an empty log; the whole list is rewritten on every call.

    Args:
        image_url: URL of the original image.
        sketch_url: URL of the generated line-art image.
        prompt: Prompt that was used.
        log_file: Path of the JSON log file.
    """
    log_data = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "image_url": image_url,
        "sketch_url": sketch_url,
        "prompt": prompt
    }

    # Load existing records; a missing or corrupt file starts a fresh log.
    try:
        with open(log_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        data = []

    # The file may hold valid JSON that is not a list (e.g. a single object);
    # wrap it so append() works and the existing content is not lost.
    if not isinstance(data, list):
        logger.warning(f"日志文件 {log_file} 的内容不是列表,已包装为列表后追加记录")
        data = [data]

    data.append(log_data)

    with open(log_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    logger.info(f"✅ 已保存记录到 {log_file}")
- # ==================== 批量处理功能 ====================
def _write_batch_results(log_file: str, results: list) -> bool:
    """Write the batch result list to ``log_file`` as JSON; return True on success."""
    try:
        with open(log_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        return True
    except Exception as e:
        logger.error(f"保存文件时出错: {e}")
        return False


def batch_generate_sketch_from_urls(
    image_urls: List[str],
    prompt: Optional[str] = None,
    max_retries: int = MAX_RETRY_COUNT,
    auto_check: bool = True,
    log_file: Optional[str] = None
) -> dict:
    """
    Generate line-art images for a list of image URLs.

    When ``log_file`` is given, the cumulative result list (successes and
    failures) is written to it after every item, so an interrupted run still
    leaves a usable file, and once more at the end. Any previous content of
    the file is replaced.

    Args:
        image_urls: Image URLs to process.
        prompt: Prompt, forwarded to ``generate_sketch``.
        max_retries: Maximum number of attempts per image.
        auto_check: Whether to run the automatic quality check.
        log_file: Path of the JSON result file; nothing is saved when None.

    Returns:
        Dict with ``success_count``, ``fail_count``, ``total`` and
        ``results`` (one record per input URL).
    """
    logger.info(f"开始批量处理 {len(image_urls)} 个图片")

    success_count = 0
    fail_count = 0
    all_results = []

    for idx, image_url in enumerate(image_urls, 1):
        logger.info(f"\n{'='*50}")
        logger.info(f"处理第 {idx}/{len(image_urls)} 个图片")
        logger.info(f"{'='*50}")

        # Only the generation call is guarded, so a logging problem can no
        # longer turn an already-counted success into an extra failure.
        try:
            sketch_url = generate_sketch(
                image_url=image_url,
                prompt=prompt,
                max_retries=max_retries,
                auto_check=auto_check
            )
            error = None if sketch_url else "生成失败"
        except Exception as e:
            logger.error(f"处理图片时出错: {e}")
            sketch_url = None
            error = str(e)

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        if sketch_url:
            result = {
                "timestamp": timestamp,
                "image_url": image_url,
                "sketch_url": sketch_url,
                "prompt": prompt or DEFAULT_SKETCH_PROMPT,
                "success": True
            }
            success_count += 1
        else:
            result = {
                "timestamp": timestamp,
                "image_url": image_url,
                "success": False,
                "error": error
            }
            fail_count += 1
        all_results.append(result)

        # Checkpoint in the same schema as the final file. Previously a
        # per-item append in a different schema was written here and then
        # overwritten at the end.
        if log_file:
            _write_batch_results(log_file, all_results)

    # Final write (also covers an empty input list).
    if log_file and _write_batch_results(log_file, all_results):
        logger.info(f"✅ 已保存处理结果到: {log_file}")

    logger.info(f"\n{'='*50}")
    logger.info(f"✅ 批量处理完成!")
    logger.info(f" 成功: {success_count} 个")
    logger.info(f" 失败: {fail_count} 个")
    logger.info(f"{'='*50}")

    return {
        "success_count": success_count,
        "fail_count": fail_count,
        "total": len(image_urls),
        "results": all_results
    }
def batch_generate_sketch_from_directory(
    directory: str,
    prompt: Optional[str] = None,
    max_retries: int = MAX_RETRY_COUNT,
    auto_check: bool = True,
    log_file: Optional[str] = None
) -> dict:
    """
    Scan a directory for images and generate line art for each of them.

    Every image is uploaded first; images whose upload fails are logged and
    skipped, so they are not part of the returned counts.

    Args:
        directory: Directory containing the images.
        prompt: Prompt, forwarded to the batch URL processor.
        max_retries: Maximum number of attempts per image.
        auto_check: Whether to run the automatic quality check.
        log_file: Path of the JSON result file.

    Returns:
        Result dict as returned by ``batch_generate_sketch_from_urls``; all
        counts are zero when no image file is found.
    """
    image_files = get_image_files(directory)

    if not image_files:
        logger.warning(f"在目录 {directory} 中未找到图片文件")
        return {
            "success_count": 0,
            "fail_count": 0,
            "total": 0,
            "results": []
        }

    logger.info(f"找到 {len(image_files)} 个图片文件")

    # Upload every image to obtain its URL.
    image_urls = []
    for image_path in image_files:
        try:
            image_url = upload_image(image_path)
        except Exception as e:
            logger.error(f"上传图片失败 {image_path}: {e}")
            continue

        if image_url:
            image_urls.append(image_url)
        else:
            # An upload that returns no URL used to be dropped silently.
            logger.warning(f"上传图片失败 {image_path}: 未返回URL")

    logger.info(f"成功上传 {len(image_urls)} 个图片")

    return batch_generate_sketch_from_urls(
        image_urls=image_urls,
        prompt=prompt,
        max_retries=max_retries,
        auto_check=auto_check,
        log_file=log_file
    )
- # ==================== 主程序入口 ====================
if __name__ == "__main__":
    # Demo: generate line art for a single image URL and report the outcome.
    demo_image_url = "https://example.com/garment.jpg"
    sketch_url = generate_sketch(demo_image_url)
    print(f"✅ 线稿图生成成功: {sketch_url}" if sketch_url else "❌ 线稿图生成失败")
|