# ocr_compare.py -- FastAPI service: barcode detection, OCR comparison and history lookup.

  1. import os, io, time, json, csv, ast, traceback, cv2, asyncio, datetime
  2. from typing import Optional
  3. from concurrent.futures import ThreadPoolExecutor
  4. import numpy as np
  5. import pandas as pd
  6. from PIL import Image
  7. from fastapi import FastAPI, File, UploadFile, Form
  8. from fastapi.responses import JSONResponse
  9. from fastapi.middleware.cors import CORSMiddleware
  10. from fastapi.staticfiles import StaticFiles
  11. from config import (
  12. ocr_images_dir, model, header, port, file_url,
  13. Search, ID, Matio, Item
  14. )
  15. from utils import get_time, detection, sql_product, image_handle, Compare
  16. detect_instance = detection()
  17. ocr = detect_instance.ocr
  18. app = FastAPI()
  19. executor = ThreadPoolExecutor(max_workers=15)
  20. camera_connections = {}
  21. app.mount(header, StaticFiles(directory=ocr_images_dir), name="static")
  22. app.add_middleware(
  23. CORSMiddleware,
  24. allow_origins=["*"],
  25. allow_credentials=True,
  26. allow_methods=["*"],
  27. allow_headers=["*"],
  28. )
  29. @app.post('/detect_barcode')
  30. async def detect_barcode(file: UploadFile = File(...)):
  31. try:
  32. # 确保目录存在
  33. tmp_dir = os.path.join(ocr_images_dir, 'tmp_images')
  34. os.makedirs(tmp_dir, exist_ok=True)
  35. # 读取并处理上传的图像
  36. contents = await file.read()
  37. upload_timestamp = datetime.datetime.now().strftime("%Y_%m_%d-%H__%M__%S")
  38. image = Image.open(io.BytesIO(contents))
  39. image = image_handle.correct_image_orientation(image)
  40. image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
  41. # 保存原始图像
  42. image_origin_path = os.path.join(tmp_dir, f"origin_images_{upload_timestamp}.png")
  43. cv2.imwrite(image_origin_path, image)
  44. time1 = time.time()
  45. # 预测并处理结果
  46. results = model.predict(image, conf=0.3)
  47. print(f"predict time is:{time.time() - time1}")
  48. for result in results:
  49. for obb in result.obb:
  50. if int(obb.cls.item()) == 15 and float(obb.conf.item()) > 0:
  51. points = obb.xyxyxyxy[0].cpu().numpy()
  52. cropped_image = image_handle.crop_image_second(image.copy(), points)
  53. barcode_image_path = os.path.join(tmp_dir, f"barcode_images_{upload_timestamp}.png")
  54. cv2.imwrite(barcode_image_path, cropped_image)
  55. barcode = await asyncio.get_event_loop().run_in_executor(
  56. executor, detect_instance.detect_barcode_ocr, barcode_image_path)
  57. if barcode:
  58. return {'barcode': barcode, 'image_origin_path': image_origin_path, 'upload_timestamp': upload_timestamp}
  59. return {'barcode': None, 'image_origin_path': image_origin_path, 'upload_timestamp': upload_timestamp}
  60. except Exception as e:
  61. print(e)
  62. return None, None, None
  63. @app.post('/get_barcode')
  64. async def get_barcode(message: Optional[str] = Form(None), file: Optional[UploadFile] = File(None)):
  65. try:
  66. result_dir = os.path.join(ocr_images_dir, 'results')
  67. os.makedirs(result_dir, exist_ok=True)
  68. if message:
  69. item_dict = json.loads(message)
  70. message_new = Matio(**item_dict)
  71. barcode, image_origin_path, upload_timestamp, barcode_type, matio_id = message_new.barcode, message_new.image_origin_path, message_new.upload_time, message_new.barcode_type, message_new.matio_id
  72. elif file:
  73. barcode_data = await detect_barcode(file)
  74. barcode, image_origin_path, upload_timestamp = barcode_data.get('barcode'), barcode_data.get('image_origin_path'), barcode_data.get('upload_timestamp')
  75. barcode_type, matio_id = Matio().barcode_type, Matio().matio_id
  76. if not barcode:
  77. return JSONResponse(content={"code": 0, "decs": "未识别到barcode,请重新上传"}, status_code=500)
  78. image = cv2.imread(image_origin_path)
  79. image_width, image_height = image.shape[1], image.shape[0]
  80. detection_timestamp = datetime.datetime.now().strftime("%Y_%m_%d-%H__%M__%S")
  81. data_result, matio_id, color_id, _ = sql_product.sql_information(barcode=barcode, barcode_type=barcode_type, matio_id=matio_id)
  82. results_dir = os.path.join(result_dir, matio_id)
  83. os.makedirs(results_dir, exist_ok=True)
  84. result_image_path = os.path.join(results_dir, f"ocr_result_images_{detection_timestamp}.png")
  85. regular_image_path = os.path.join(results_dir, f"regular_images_{upload_timestamp}.png")
  86. cv2.imwrite(regular_image_path, image)
  87. ocr_result = []
  88. result = ocr.ocr(regular_image_path, cls=True)
  89. for line in result[0]:
  90. bbox = [[int(x) for x in point] for point in line[0]]
  91. text = line[1][0]
  92. image = image_handle.draw_box_and_text(image, bbox, text)
  93. ocr_result.append([text])
  94. resize = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
  95. cv2.imwrite(result_image_path, resize)
  96. ocr_image_url = file_url + result_image_path.replace(ocr_images_dir, '/')
  97. regular_image_url = file_url + regular_image_path.replace(ocr_images_dir, '/')
  98. data_set, log = Compare.compare(ocr_result=ocr_result, dataset=data_result)
  99. color_set = {k: data_set.get(k) for k in ['color_name', 'color_id', '产品名称', 'language']}
  100. color_set['颜色'] = color_set.pop('color_name') if 'color_name' in color_set else ''
  101. color_set['色号'] = color_set.pop('color_id') if 'color_id' in color_set else ''
  102. color_set['语言'] = color_set.pop('language') if 'language' in color_set else ''
  103. name_set = {k: data_set.get(k) for k in ['号型', 'size_id', '产品名称', 'language']}
  104. if 'size_id' in name_set:
  105. name_set.pop('size_id')
  106. name_set['语言'] = name_set.pop('language') if 'language' in name_set else ''
  107. for k in ['color_name', 'color_id', '号型', 'size_id', '产品名称', 'language']:
  108. data_set.pop(k, None)
  109. size_compare_flag, size_compare_logs = sql_product.size_information(matio_id, color_id)
  110. color_compare_flag, color_compare_logs = sql_product.color_information(matio_id)
  111. csv_file_path = os.path.join(ocr_images_dir, 'history.csv')
  112. dicts = {
  113. "id": str(len(pd.read_csv(csv_file_path)) if os.path.exists(csv_file_path) else '0'),
  114. "matio_id": matio_id,
  115. "item_num": matio_id.split('-')[0],
  116. "difference": '1' if log else '0',
  117. "upload_time": upload_timestamp.replace('__', ':').replace('_', '/').replace('-', ' '),
  118. "ocr_time": detection_timestamp.replace('__', ':').replace('_', '/').replace('-', ' '),
  119. "logs": log,
  120. "regular_image": regular_image_url,
  121. "ocr_image": ocr_image_url,
  122. "size_compare_flag": size_compare_flag,
  123. "size_compare_logs": size_compare_logs if size_compare_logs else "'None'",
  124. "color_compare_flag": color_compare_flag,
  125. "color_compare_logs": color_compare_logs if color_compare_logs else "'None'",
  126. "barcode_type": barcode_type,
  127. "data_set": data_set,
  128. "color_set": color_set,
  129. "name_set": name_set
  130. }
  131. with open(csv_file_path, 'a', newline='', encoding='utf-8') as csv_file:
  132. writer = csv.DictWriter(csv_file, fieldnames=dicts.keys())
  133. if csv_file.tell() == 0:
  134. writer.writeheader()
  135. writer.writerow(dicts)
  136. return JSONResponse(content={"code": 1, "decs": "识别成功"}, status_code=200)
  137. except Exception as e:
  138. traceback.print_exc()
  139. return JSONResponse(content={"code": 0, "decs": "失败,出现错误,请重试"}, status_code=500)
  140. @app.post("/test/")
  141. async def test(file: Optional[UploadFile] = File(None),
  142. item: Optional[str] = Form(None)):
  143. if item:
  144. item_dict = json.loads(item)
  145. item_model = Item(**item_dict)
  146. return {"item": item_model.name + item_model.description, "filename": file}
  147. else:
  148. return {'status': 'succeed', "filename":file.filename}
  149. @app.post('/search')
  150. async def search_info(message: Search):
  151. try:
  152. csv_path = os.path.join(ocr_images_dir, 'history.csv')
  153. if not os.path.exists(csv_path):
  154. return {"code": "1", "decs": None, "data": None}
  155. df = pd.read_csv(csv_path)
  156. df_length = len(df)
  157. # 应用过滤条件
  158. filters = {
  159. 'barcode_type': message.barcode_type,
  160. 'matio_id': message.matio_id,
  161. 'item_num': message.item_num,
  162. 'difference': int(message.difference) if message.difference else None
  163. }
  164. for col, condition in filters.items():
  165. if condition or condition==0:
  166. df = df[df[col] == condition]
  167. # 时间过滤
  168. for time_col, start_time, end_time in [
  169. ('upload_time', message.uploadStartTime, message.uploadEndTime),
  170. ('ocr_time', message.ocrStartTime, message.ocrEndTime)
  171. ]:
  172. if start_time and end_time:
  173. df[time_col] = pd.to_datetime(df[time_col])
  174. start = datetime.datetime(**get_time(start_time))
  175. end = datetime.datetime(**get_time(end_time))
  176. df = df[(df[time_col] >= start) & (df[time_col] <= end)]
  177. df[time_col] = df[time_col].dt.strftime("%Y/%m/%d %H:%M:%S")
  178. # 排序和分页
  179. df = df.sort_values(by="upload_time", ascending=False)
  180. start = (message.pageNum - 1) * message.pageSize
  181. end = start + message.pageSize
  182. page_data = df.iloc[start:end].to_dict(orient="records")
  183. # 处理数据
  184. for record in page_data:
  185. for key, value in record.items():
  186. try:
  187. record[key] = ast.literal_eval(value)
  188. except (SyntaxError, ValueError):
  189. pass
  190. data = {
  191. "records": page_data,
  192. "total": str(df_length),
  193. "size": str(message.pageSize),
  194. "current": str(message.pageNum),
  195. "orders": [],
  196. "optimizeCountSql": True,
  197. "searchCount": True,
  198. "countId": '',
  199. "maxLimit": '',
  200. "pages": str(df_length // message.pageSize + 1)
  201. }
  202. return {"code": "1", "decs": None, "data": data}
  203. except Exception as e:
  204. traceback.print_exc()
  205. return JSONResponse(content={"code": "0", "decs": str(e)}, status_code=500)
  206. @app.post('/get_matio_id')
  207. async def get_matio_id(file: UploadFile = File(...)):
  208. try:
  209. if not file:
  210. return {"code": "1", "matio_list": []}
  211. barcode_data = await detect_barcode(file)
  212. barcode = barcode_data.get('barcode')
  213. prefix_code_list = sql_product.sql_matio_id(prefix_code=barcode)
  214. return {
  215. "code": "1",
  216. "matio_list": prefix_code_list,
  217. 'image_origin_path': barcode_data.get('image_origin_path'),
  218. 'upload_timestamp': barcode_data.get('upload_timestamp'),
  219. 'barcode': barcode
  220. }
  221. except Exception as e:
  222. return JSONResponse(content={"code": "0", "decs": str(e)}, status_code=500)
  223. @app.post('/show')
  224. async def show_info(message: ID):
  225. try:
  226. csv_path = os.path.join(ocr_images_dir, 'history.csv')
  227. df = pd.read_csv(csv_path)
  228. data = df[df['id'] == int(message.id)].iloc[0].to_dict()
  229. for key, value in data.items():
  230. try:
  231. data[key] = ast.literal_eval(value)
  232. except (ValueError, SyntaxError):
  233. if key == 'logs':
  234. data[key] = json.loads(value.replace("'", '"'))
  235. return {"code": "1", "decs": None, "data": {"records": data}}
  236. except Exception as e:
  237. return JSONResponse(content={"code": "0", "decs": str(e)}, status_code=500)
  238. @app.post('/history')
  239. async def history_info(message: ID):
  240. try:
  241. csv_path = os.path.join(ocr_images_dir, 'history.csv')
  242. df = pd.read_csv(csv_path)
  243. # 获取指定 id 对应的 matio_id
  244. matio_id = df.loc[df['id'] == int(message.id), 'matio_id'].iloc[0]
  245. # 筛选并排序数据
  246. data = df[df['matio_id'] == matio_id].sort_values(by="upload_time", ascending=False)
  247. # 格式化时间并转换为字典
  248. data['upload_time'] = pd.to_datetime(data["upload_time"]).dt.strftime("%Y/%m/%d %H:%M:%S")
  249. records = data.to_dict(orient='records')
  250. # 处理 logs 字段
  251. for record in records:
  252. record['logs'] = json.loads(record['logs'].replace("'", '"'))
  253. return {"code": "1", "decs": None, "data": {"records": records}}
  254. except Exception as e:
  255. traceback.print_exc()
  256. return JSONResponse(content={"code": "0", "decs": str(e)}, status_code=500)
  257. if __name__ == "__main__":
  258. import uvicorn
  259. uvicorn.run(app, host="0.0.0.0", port=port)