import React, { useState, useEffect, useMemo } from 'react';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Label } from '@/components/ui/label';
import { Input } from '@/components/ui/input';
import { Badge } from '@/components/ui/badge';
import { Alert, AlertDescription } from '@/components/ui/alert';
import { Progress } from '@/components/ui/progress';
import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group';
import { ScrollArea } from '@/components/ui/scroll-area';
import { Separator } from '@/components/ui/separator';
import {
User,
Play,
Loader2,
CheckCircle,
Wand2,
Sparkles,
Image as ImageIcon,
FileText,
Zap,
Save,
Upload,
Plus,
Settings,
Palette,
Camera,
Shirt,
ImageIcon as ImageIcon2
} from 'lucide-react';
import { motion, AnimatePresence } from 'framer-motion';
import { materialLibraryAPI } from '@/api/materials';
import { textTemplateAPI } from '@/api/textTemplates';
import { aiSwapAPI } from '@/api/ai_swap';
import { aiSwapBgAPI } from '@/api/ai_swap_bg';
import { useAuth } from '@/contexts/AuthContext';
/**
 * Material selection card component.
 *
 * Props destructured here: `item`, `isSelected`, `onSelect`, and `type`
 * (defaults to 'image'; it is not referenced in the visible body).
 * The visible body passes `item.id` to `onSelect`, branches on
 * `item.file_url` (the fallback branch contains a placeholder text), and has
 * one more branch guarded by `isSelected`.
 *
 * NOTE(review): the JSX element tags appear to be missing from this view of
 * the file — confirm against the original source before editing the markup.
 */
const MaterialCard = ({ item, isSelected, onSelect, type = 'image' }) => (
onSelect(item.id)}
>
{item.file_url ? (
) : (
暂无图片
)}
{isSelected && (
)}
);
/**
 * Template selection item component.
 *
 * Props destructured here: `template`, `isSelected`, `onSelect`, and `type`.
 * The visible body passes `template.id` to `onSelect`, shows `template.name`,
 * has a branch guarded by `isSelected`, and shows `template.tag`, falling back
 * to `type` when the tag is falsy.
 *
 * NOTE(review): the JSX element tags appear to be missing from this view of
 * the file — confirm against the original source before editing the markup.
 */
const TemplateItem = ({ template, isSelected, onSelect, type }) => (
onSelect(template.id)}
>
{template.name}
{isSelected && }
{template.tag || type}
);
/**
 * AI generation page component.
 *
 * Reads the current user from `useAuth`, keeps the loaded material/template
 * lists and the user's selections in local state, and tracks the state of the
 * submitted generation tasks (IDs, aggregated status, polling timer).
 */
export default function AIGeneration() {
const { user } = useAuth();
// State management
// Lists loaded from the backend (filled by loadMaterials / loadTemplates).
const [originalImages, setOriginalImages] = useState([]);
const [ipCharacters, setIpCharacters] = useState([]);
const [clothingItems, setClothingItems] = useState([]);
const [sceneTemplates, setSceneTemplates] = useState([]);
const [copyTemplates, setCopyTemplates] = useState([]);
// IDs of the entries currently selected in each list.
const [selectedOriginalImages, setSelectedOriginalImages] = useState([]);
const [selectedIpCharacters, setSelectedIpCharacters] = useState([]);
const [selectedClothingItems, setSelectedClothingItems] = useState([]);
const [selectedSceneTemplates, setSelectedSceneTemplates] = useState([]);
const [selectedCopyTemplates, setSelectedCopyTemplates] = useState([]);
// Generation settings.
const [taskName, setTaskName] = useState('');
const [selectedTaskTypes, setSelectedTaskTypes] = useState([]);
const [quantityPerGroup, setQuantityPerGroup] = useState(1);
// Generation run state.
const [isGenerating, setIsGenerating] = useState(false);
const [generationProgress, setGenerationProgress] = useState(0);
const [currentTask, setCurrentTask] = useState('');
// `message` is either null or an object of the form { type, text }.
const [message, setMessage] = useState(null);
const [loading, setLoading] = useState(false);
const [currentTaskId, setCurrentTaskId] = useState(null); // kept for compatibility; holds the first task ID
const [currentTaskIds, setCurrentTaskIds] = useState([]); // IDs of all submitted tasks
const [taskStatus, setTaskStatus] = useState(null); // aggregated status summary
// Timer ID returned by setInterval while polling is active, otherwise null.
const [pollingInterval, setPollingInterval] = useState(null);
// Load materials and templates when a user is available; re-runs whenever `user` changes.
useEffect(() => {
if (user) {
loadMaterials();
loadTemplates();
}
}, [user]);
// Clear the polling timer: the cleanup runs when `pollingInterval` changes
// and on unmount, clearing the timer ID captured by that effect run.
useEffect(() => {
return () => {
if (pollingInterval) {
clearInterval(pollingInterval);
}
};
}, [pollingInterval]);
/**
 * Loads the three material categories ('original', 'face', 'cloth') for the
 * signed-in user and stores each list, with a `file_url` added to every item,
 * in its matching state slot.
 *
 * The three requests are independent, so they are issued together and settled
 * individually: a category that loads successfully is stored even when another
 * category fails, and any failure is surfaced through `setMessage`.
 *
 * @returns {Promise<void>}
 */
const loadMaterials = async () => {
  // Use the ID of the currently signed-in user.
  const userId = user?.id;
  if (!userId) {
    setMessage({ type: 'error', text: '请先登录' });
    return;
  }

  // API material type -> log label and state setter. Note that the
  // "product material" tab is served by the API type 'cloth'.
  const categories = [
    { type: 'original', label: '原始素材', store: setOriginalImages },
    { type: 'face', label: 'IP素材', store: setIpCharacters },
    { type: 'cloth', label: '产品素材', store: setClothingItems },
  ];

  setLoading(true);
  try {
    const outcomes = await Promise.allSettled(
      categories.map(({ type }) => materialLibraryAPI.getMaterials(userId, type)),
    );

    let failedCount = 0;
    outcomes.forEach((outcome, index) => {
      const { label, store } = categories[index];
      if (outcome.status === 'rejected') {
        failedCount += 1;
        console.error('加载素材失败:', outcome.reason);
        return;
      }

      const response = outcome.value;
      console.log(`${label}API响应:`, response);
      if (!response?.success) {
        return;
      }

      // A successful response without an `images` array is treated as empty
      // instead of throwing on `.map`.
      const images = Array.isArray(response.images) ? response.images : [];
      const withUrls = images.map((item) => ({
        ...item,
        file_url: getImageUrl(item),
      }));
      store(withUrls);
      console.log(`${label}数据:`, withUrls);
    });

    if (failedCount > 0) {
      setMessage({ type: 'error', text: '加载素材失败,请稍后重试' });
    }
  } catch (error) {
    console.error('加载素材失败:', error);
    setMessage({ type: 'error', text: '加载素材失败,请稍后重试' });
  } finally {
    setLoading(false);
  }
};
/**
 * Builds the URL under which a material's file is served.
 *
 * Only the last segment of `item.stored_path` is used; the path may use
 * either backslash or forward-slash separators. The base URL comes from
 * `VITE_API_BASE_URL`, falling back to 'http://localhost:8000'.
 *
 * @param {{ stored_path?: string } | null | undefined} item - Material record.
 * @returns {string} The image URL, or '' when no usable file name is available.
 */
const getImageUrl = (item) => {
  const storedPath = item?.stored_path;
  if (!storedPath) return '';

  const filename = storedPath.split(/[\\/]/).pop();
  if (!filename) return '';

  // Drop trailing slashes so a base such as "https://host/" does not yield "//materials".
  const baseURL = (import.meta.env.VITE_API_BASE_URL || 'http://localhost:8000').replace(/\/+$/, '');

  // Encode the file name: characters such as ' ', '#', '?' or '%' would
  // otherwise change the meaning of the URL.
  return `${baseURL}/materials/${encodeURIComponent(filename)}`;
};
/**
 * Fetches the user's scene ('prompt') and copy ('copywrite') text templates
 * in parallel (page 1, up to 100 each) and stores them in state.
 *
 * Each API record is reshaped to { id, name, tag, created_date, content };
 * a missing, unsuccessful or malformed response yields an empty list.
 *
 * @returns {Promise<void>}
 */
const loadTemplates = async () => {
  const userId = user?.id;
  if (!userId) {
    setMessage({ type: 'error', text: '请先登录' });
    return;
  }

  // Converts one API response into the list shape used by the pickers.
  const toTemplateList = (response) => {
    const isUsable = response && response.success && Array.isArray(response.templates);
    if (!isUsable) {
      return [];
    }
    return response.templates.map((record) => {
      const template = {
        id: record.id,
        name: record.text_name,
        tag: record.text_label,
        created_date: record.created_at,
        content: record.text_content,
      };
      return template;
    });
  };

  try {
    const responses = await Promise.all([
      textTemplateAPI.getTextTemplates(userId, 'prompt', 1, 100),
      textTemplateAPI.getTextTemplates(userId, 'copywrite', 1, 100),
    ]);
    const sceneResponse = responses[0];
    const copyResponse = responses[1];
    setSceneTemplates(toTemplateList(sceneResponse));
    setCopyTemplates(toTemplateList(copyResponse));
  } catch (error) {
    console.error('加载模板失败:', error);
    setMessage({ type: 'error', text: '加载模板失败,请稍后重试' });
  }
};
/**
 * Polls the status of every task in `taskIds` once and aggregates the result
 * into the progress bar, the status summary and the current-step label.
 *
 * A status request that rejects is counted as a failed task. Tasks that are
 * neither completed, failed nor processing are counted as pending. When every
 * task is completed or failed, the run is marked finished and polling is
 * stopped.
 *
 * @param {Array} taskIds - IDs of the submitted tasks.
 * @returns {Promise<void>}
 */
const pollAllTasksStatus = async (taskIds) => {
  try {
    if (!taskIds || taskIds.length === 0) return;

    const responses = await Promise.allSettled(taskIds.map((id) => aiSwapAPI.getTaskStatus(id)));
    const fulfilled = responses
      .filter((r) => r.status === 'fulfilled')
      .map((r) => r.value);
    const rejectedCount = responses.length - fulfilled.length;
    const countWithStatus = (status) => fulfilled.filter((r) => r.status === status).length;

    const total = taskIds.length;
    const completed = countWithStatus('completed');
    const failed = countWithStatus('failed') + rejectedCount;
    const processing = countWithStatus('processing');
    const pending = total - completed - failed - processing;

    // Aggregate progress: mean over the status responses that arrived.
    const avgProgress = fulfilled.length > 0
      ? Math.round(
        fulfilled.reduce((sum, r) => sum + (typeof r.progress === 'number' ? r.progress : 0), 0) / fulfilled.length,
      )
      : 0;
    setGenerationProgress(avgProgress);
    setTaskStatus({
      total,
      completed,
      failed,
      processing,
      pending,
      details: fulfilled,
    });

    if (completed + failed === total) {
      setCurrentTask('生成完成');
      setMessage({ type: failed === 0 ? 'success' : 'error', text: failed === 0 ? '全部生成完成!' : `部分失败:成功 ${completed},失败 ${failed}` });
      setIsGenerating(false);
      // Do not read `pollingInterval` here: this function is invoked from a
      // setInterval callback created in an earlier render, so the captured
      // value is stale (still null) and a guard on it would never stop the
      // timer. Resetting the state is enough: the effect keyed on
      // `pollingInterval` clears the previous timer in its cleanup.
      setPollingInterval(null);
    } else if (processing > 0 || pending > 0) {
      setCurrentTask('AI处理中...');
    }
  } catch (error) {
    console.error('轮询任务状态失败:', error);
    setMessage({ type: 'error', text: '获取任务状态失败' });
  }
};
/**
 * Toggles `id` in a selection list held in React state.
 *
 * @param {Function} setter - State setter; it is called with an updater function.
 * @param {Array} selectedItems - Accepted for caller compatibility; not used.
 * @param {*} id - ID to remove when already selected, otherwise to append.
 */
const handleSelection = (setter, selectedItems, id) => {
  const toggleId = (current) => (
    current.includes(id)
      ? current.filter((existing) => existing !== id)
      : [...current, id]
  );
  setter(toggleId);
};
/**
 * Toggles a task type in the selected-task-types state: removes it when it is
 * already selected, appends it otherwise.
 *
 * @param {string} taskType - Task type label to toggle.
 */
const handleTaskTypeToggle = (taskType) => {
  const toggleType = (current) => (
    current.includes(taskType)
      ? current.filter((existing) => existing !== taskType)
      : [...current, taskType]
  );
  setSelectedTaskTypes(toggleType);
};
/**
 * Cancels every task of the current run and resets the generation UI.
 *
 * Cancellation is best-effort: a cancel request is sent for every task ID and
 * the UI is reset regardless of the individual outcomes. Because
 * `Promise.allSettled` never rejects, the outcomes are inspected explicitly so
 * that failed cancel requests are logged and reported instead of being shown
 * as a success.
 *
 * @returns {Promise<void>}
 */
const handleCancelTask = async () => {
  if (!currentTaskIds || currentTaskIds.length === 0) return;
  try {
    const outcomes = await Promise.allSettled(currentTaskIds.map((id) => aiSwapAPI.cancelTask(id)));
    const failures = outcomes.filter((outcome) => outcome.status === 'rejected');
    failures.forEach((failure) => console.error('取消任务失败:', failure.reason));

    if (failures.length === 0) {
      setMessage({ type: 'success', text: '已取消全部任务' });
    } else {
      const succeeded = outcomes.length - failures.length;
      setMessage({ type: 'error', text: `取消任务失败:成功 ${succeeded},失败 ${failures.length}` });
    }

    setIsGenerating(false);
    setCurrentTaskId(null);
    setCurrentTaskIds([]);
    setTaskStatus(null);
    if (pollingInterval) {
      clearInterval(pollingInterval);
      setPollingInterval(null);
    }
  } catch (error) {
    console.error('取消任务失败:', error);
    setMessage({ type: 'error', text: '取消任务失败' });
  }
};
/**
 * Validates the current selection, submits one swap task per
 * face × clothing × scene × copy-template combination, and starts polling the
 * submitted tasks every 2 seconds.
 *
 * Submissions are sent sequentially. A submission that fails, whether by an
 * unsuccessful response or by a thrown error, is logged and skipped so that
 * tasks already accepted are still tracked and polled. When no copy template
 * is selected, each face/clothing/scene combination is submitted once without
 * a copy style.
 *
 * @returns {Promise<void>}
 */
const handleGenerate = async () => {
  // Require a signed-in user.
  if (!user?.id) {
    setMessage({ type: 'error', text: '请先登录' });
    return;
  }
  // Validate the selection.
  if (selectedIpCharacters.length === 0) {
    setMessage({ type: 'error', text: '请选择至少一个IP形象' });
    return;
  }
  if (selectedClothingItems.length === 0) {
    setMessage({ type: 'error', text: '请选择至少一件服装' });
    return;
  }
  if (!selectedTaskTypes.includes('换脸') ||
    !selectedTaskTypes.includes('换衣')) {
    setMessage({ type: 'error', text: '请选择换脸和换衣任务类型' });
    return;
  }
  if (selectedSceneTemplates.length === 0) {
    setMessage({ type: 'error', text: '请选择至少一个场景模板' });
    return;
  }

  setIsGenerating(true);
  setGenerationProgress(0);
  setMessage(null);
  setCurrentTaskId(null);
  // Also clear the previous run's ID list and step label so they are not
  // displayed while the new tasks are still being submitted.
  setCurrentTaskIds([]);
  setCurrentTask('');
  setTaskStatus(null);

  // Builds the prompt for one combination; `copy` is null when no copy
  // template is selected.
  const buildPrompt = (scene, copy) => {
    let prompt = 'AI换脸换装';
    if (scene) {
      prompt += `,场景:${scene.content}`;
    }
    if (copy) {
      prompt += `,文案风格:${copy.name}`;
    }
    const trimmedTaskName = taskName.trim();
    if (trimmedTaskName) {
      prompt += `,任务:${trimmedTaskName}`;
    }
    return prompt;
  };

  try {
    // Use the ID of the currently signed-in user.
    const userId = user.id;
    // Resolve the selected IDs to records; the tasks are their Cartesian product.
    const selectedFaces = ipCharacters.filter((item) => selectedIpCharacters.includes(item.id));
    const selectedClothes = clothingItems.filter((item) => selectedClothingItems.includes(item.id));
    const selectedScenes = sceneTemplates.filter((item) => selectedSceneTemplates.includes(item.id));
    const selectedCopies = selectedCopyTemplates.length > 0
      ? copyTemplates.filter((item) => selectedCopyTemplates.includes(item.id))
      : [null];
    if (selectedFaces.length === 0 || selectedClothes.length === 0 || selectedScenes.length === 0) {
      throw new Error('选中的素材/模板不存在');
    }

    const submittedTaskIds = [];
    const totalCombos = selectedFaces.length * selectedClothes.length * selectedScenes.length * selectedCopies.length;
    console.log(`即将提交组合任务数量: ${totalCombos}`);

    // Submit sequentially to avoid a burst of simultaneous requests.
    for (const face of selectedFaces) {
      for (const cloth of selectedClothes) {
        for (const scene of selectedScenes) {
          for (const copy of selectedCopies) {
            const swapData = {
              user_id: userId,
              face_image_id: face.id,
              cloth_image_id: cloth.id,
              prompt: buildPrompt(scene, copy),
              quantity: quantityPerGroup,
            };
            try {
              const resp = await aiSwapAPI.processSwap(swapData);
              if (resp && resp.success === true && resp.task_id) {
                submittedTaskIds.push(resp.task_id);
              } else {
                console.warn('组合任务提交失败:', resp);
              }
            } catch (submitError) {
              // One failed request must not discard the IDs of tasks that
              // were already accepted; log it and continue.
              console.warn('组合任务提交失败:', submitError);
            }
          }
        }
      }
    }

    if (submittedTaskIds.length > 0) {
      setCurrentTaskIds(submittedTaskIds);
      setCurrentTaskId(submittedTaskIds[0]);
      setMessage({ type: 'success', text: `已提交 ${submittedTaskIds.length}/${totalCombos} 个任务,处理中...` });
      setCurrentTask('准备素材');
      setGenerationProgress(5);
      // Start polling: aggregates the progress of all submitted tasks.
      const interval = setInterval(() => {
        pollAllTasksStatus(submittedTaskIds);
      }, 2000);
      setPollingInterval(interval);
      // Poll once immediately instead of waiting for the first tick.
      pollAllTasksStatus(submittedTaskIds);
    } else {
      setIsGenerating(false);
      setMessage({ type: 'error', text: '未能成功提交任何任务,请稍后重试' });
    }
  } catch (error) {
    setIsGenerating(false);
    console.error('换脸换装请求异常:', error);
    const errorMsg = error.message || '处理失败,请稍后重试';
    setMessage({ type: 'error', text: errorMsg });
  }
};
// Generation is allowed only when at least one IP character, one clothing
// item and one scene template are selected, both required task types are
// chosen, and no generation run is in progress.
const hasRequiredSelections = [selectedIpCharacters, selectedClothingItems, selectedSceneTemplates]
  .every((selection) => selection.length > 0);
const hasRequiredTaskTypes = ['换脸', '换衣']
  .every((requiredType) => selectedTaskTypes.includes(requiredType));
const canGenerate = hasRequiredSelections && hasRequiredTaskTypes && !isGenerating;
// Combined number of selected IP characters and clothing items.
const totalSelected = [selectedIpCharacters, selectedClothingItems]
  .reduce((count, selection) => count + selection.length, 0);
// Render output. NOTE(review): the JSX element tags appear to be missing from
// this view of the file — confirm against the original source before editing.
return (
{/* Top title bar */}
智能生成
选择素材和模板,使用ComfyUI一键生成营销内容
{/* Message banner */}
{message && (
)}
{/* Main content area */}
{/* Column 1: select original images */}
选择原图
{selectedOriginalImages.length}
{loading ? (
) : originalImages.length > 0 ? (
{originalImages.map((image) => (
handleSelection(setSelectedOriginalImages, selectedOriginalImages, id)}
/>
))}
) : (
)}
{/* Column 2: select IP characters */}
选择IP形象
{selectedIpCharacters.length}
{loading ? (
) : ipCharacters.length > 0 ? (
{ipCharacters.map((character) => (
handleSelection(setSelectedIpCharacters, selectedIpCharacters, id)}
/>
))}
) : (
)}
{/* Column 3: select clothing */}
选择服装
{selectedClothingItems.length}
{loading ? (
) : clothingItems.length > 0 ? (
{clothingItems.map((clothing) => (
handleSelection(setSelectedClothingItems, selectedClothingItems, id)}
/>
))}
) : (
)}
{/* Column 4: generation settings */}
生成配置
{/* Task name */}
{/* Task types */}
{['换脸', '换衣', '换背景'].map((taskType) => (
handleTaskTypeToggle(taskType)}
className="w-4 h-4 text-blue-600 bg-gray-100 border-gray-300 rounded focus:ring-blue-500"
/>
))}
{/* Quantity generated per group */}
setQuantityPerGroup(parseInt(e.target.value) || 1)}
className="mt-2 w-full"
/>
{/* Scene templates */}
{selectedSceneTemplates.length}
{sceneTemplates.map((template) => (
handleSelection(setSelectedSceneTemplates, selectedSceneTemplates, id)}
type="scene"
/>
))}
{/* Copy templates */}
{selectedCopyTemplates.length}
{copyTemplates.map((template) => (
handleSelection(setSelectedCopyTemplates, selectedCopyTemplates, id)}
type="copy"
/>
))}
{/* Generation progress */}
{isGenerating && (
生成进度
{currentTaskIds && currentTaskIds.length > 0 && (
)}
{currentTask}
{generationProgress}%
{currentTaskIds && currentTaskIds.length > 0 && (
任务数: {currentTaskIds.length}
示例任务ID: {currentTaskIds[0]}
{taskStatus && (
完成: {taskStatus.completed}
进行中: {taskStatus.processing}
等待: {taskStatus.pending}
失败: {taskStatus.failed}
)}
)}
)}
);
}