pdf_parser.py 6.06 KB
Newer Older
kihoon.lee's avatar
upload  
kihoon.lee committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import io
import os
import fitz  # PyMuPDF
from PIL import Image, UnidentifiedImageError
import logging
import pandas as pd
from typing import Union, Optional, List, Tuple
from modules.ocr import ReaderForEasyOCR
import asyncio

# Module-wide logger. getLogger() is called without a name, so this is the
# root logger rather than a logger specific to this module.
logger = logging.getLogger()

def prepare_inputs(path_or_content: Union[str, os.PathLike, bytes, bytearray]) -> fitz.Document:
    """
    Open a PDF with PyMuPDF from a file path or from in-memory content.

    Args:
        path_or_content: Either a file system path (``str`` or
            ``os.PathLike``) or the raw PDF data (``bytes`` / ``bytearray``).
    Returns:
        The opened PyMuPDF document object. It is not closed here; closing
        it is left to the caller.
    Raises:
        ValueError: If ``path_or_content`` is neither a path nor raw bytes.
    """
    if isinstance(path_or_content, (str, os.PathLike)):
        # os.fsdecode normalises str / PathLike (including bytes-based paths)
        # to a plain str path before handing it to PyMuPDF.
        return fitz.open(os.fsdecode(path_or_content))

    if isinstance(path_or_content, (bytes, bytearray)):
        # In-memory content: PyMuPDF needs the file type spelled out because
        # there is no file extension to infer it from.
        return fitz.open(stream=path_or_content, filetype="pdf")

    raise ValueError("Invalid input type")

class PDFParser:
    """
    Parse a PDF into per-page text using PyMuPDF.

    For every page the parser collects embedded images (optionally with OCR
    text), tables (rendered as CSV) and plain text blocks, orders them by
    their vertical position and joins them into a single string.
    """

    # Minimum width and height (in pixels) an image needs before OCR is run.
    MIN_OCR_IMAGE_SIZE = 150

    def __init__(self, use_ocr: bool, ocr_reader: Optional["ReaderForEasyOCR"] = None):
        """
        Args:
            use_ocr: Whether OCR should be applied to embedded images.
            ocr_reader: Callable that is awaited with a PIL image and returns
                the recognised items. OCR is skipped when it is ``None``.
        """
        self.use_ocr = use_ocr
        self.ocr_reader = ocr_reader

    async def parse(self, file_path: Union[str, bytes], file_name: Optional[str] = None) -> List[dict]:
        """
        Parse every page of a PDF.

        Args:
            file_path: Path of the PDF file or its raw content.
            file_name: Name used in the generated document ids. When missing
                or empty, the base name of ``file_path`` is used if it is a
                path, otherwise ``"Unknown"``.
        Returns:
            One dict per page with the keys ``"document_id"``
            (``"<name>@<4-digit page number>"``) and ``"text"``.
        Raises:
            ValueError: Propagated from ``prepare_inputs`` for unsupported
                input types.
        """
        name = self._resolve_name(file_path, file_name)
        parsed_content = []

        doc = prepare_inputs(file_path)
        try:
            for page_number in range(len(doc)):
                page = doc.load_page(page_number)

                # Images first, then text/tables: the sort below is stable,
                # so this order decides ties between equal positions.
                elements = await self._collect_image_elements(doc, page, page_number + 1)
                elements.extend(self._collect_text_and_table_elements(page))

                # Order the elements top-to-bottom by their y coordinate.
                elements.sort(key=lambda element: element[0])

                page_content = "\n\n".join(element[2] for element in elements)
                parsed_content.append({
                    "document_id": f"{name}@{page_number + 1:04}",
                    "text": page_content,
                })
        finally:
            # Release the underlying file handle / buffer even on failure.
            doc.close()

        return parsed_content

    @staticmethod
    def _resolve_name(file_path, file_name: Optional[str]) -> str:
        """Pick the document name: explicit name, else path base name, else "Unknown"."""
        if file_name:
            return file_name
        if isinstance(file_path, (str, os.PathLike)):
            return os.path.basename(file_path)
        return "Unknown"

    async def _collect_image_elements(self, doc, page, page_label: int) -> List[Tuple[float, str, str]]:
        """Build one ``(y0, 'image', text)`` element per readable image on the page."""
        elements = []
        min_size = self.MIN_OCR_IMAGE_SIZE

        for img in page.get_images(full=True):
            xref = img[0]
            image_bytes = doc.extract_image(xref)["image"]

            try:
                image = Image.open(io.BytesIO(image_bytes))

                # WMF images cannot be processed, so they are skipped.
                if image.format == "WMF":
                    logger.warning(
                        "Skipping WMF image on page %d as it cannot be processed.",
                        page_label,
                    )
                    continue

                width, height = image.size
                ocr_text = ""

                # OCR only when enabled and the image is large enough.
                if self.use_ocr and self.ocr_reader and width >= min_size and height >= min_size:
                    # Convert to grayscale before OCR.
                    image = image.convert('L')

                    ocr_results = await self.ocr_reader(image)
                    # Each result starts with (bbox, text); any trailing
                    # fields (e.g. a confidence value) are ignored.
                    ocr_text = "\n".join([text for _bbox, text, *_rest in ocr_results])
                    ocr_text = f"(ocr)\n{ocr_text}\n(/ocr)"

            except UnidentifiedImageError:
                logger.error(
                    "Unable to identify image format for an image on page %d. Skipping this image.",
                    page_label,
                )
                continue

            # The get_images() entries carry no position, so look up where
            # the image is drawn on the page. Use the topmost occurrence and
            # fall back to the top of the page if it is not displayed.
            rects = page.get_image_rects(xref)
            y0 = min((rect.y0 for rect in rects), default=0.0)

            image_text = f"(image)\n{ocr_text}\n(/image)"
            elements.append((y0, 'image', image_text))

        return elements

    def _collect_text_and_table_elements(self, page) -> List[Tuple[float, str, str]]:
        """Build ``(y0, kind, text)`` elements for the page's text blocks and tables."""
        elements = []
        blocks = page.get_text("dict")["blocks"]

        # Tables detected by PyMuPDF and their bounding boxes.
        tables = page.find_tables()
        table_areas = [table.bbox for table in tables]
        added_tables = set()  # indices of tables that were already emitted

        for block in blocks:
            x0, y0, x1, y1 = block['bbox']

            for table_index, (table_x0, table_y0, table_x1, table_y1) in enumerate(table_areas):
                inside_table = (
                    x0 >= table_x0 and y0 >= table_y0 and
                    x1 <= table_x1 and y1 <= table_y1
                )
                if not inside_table:
                    continue
                # The block lies inside this table: emit the table once and
                # drop the block's own text, which the table already covers.
                if table_index not in added_tables:
                    table_content = self.convert_table_to_csv(tables[table_index])
                    elements.append((table_y0, 'table', table_content))
                    added_tables.add(table_index)
                break
            else:
                # Not inside any table: emit the block as plain text.
                text = " ".join(
                    span["text"]
                    for line in block.get("lines", [])
                    for span in line["spans"]
                )
                elements.append((y0, 'text', text.strip()))

        return elements

    def convert_table_to_csv(self, table):
        """
        Convert a PyMuPDF table to CSV text.

        Args:
            table: Object whose ``extract()`` returns the table as a list of rows.
        Returns:
            CSV text without index or header row, with merged cells filled in.
        """
        data = table.extract()

        df = pd.DataFrame(data)
        df = self.unmerge_cells(df)
        return df.to_csv(index=False, header=False)

    def convert_table_to_markdown(self, table):
        """
        Convert a PyMuPDF table to a markdown table.

        Args:
            table: Object whose ``extract()`` returns the table as a list of
                rows; the first row is used as the header.
        Returns:
            Markdown text for the table, or an empty string for an empty table.
        """
        data = table.extract()
        if not data:
            # No header row to build the columns from.
            return ""

        df = pd.DataFrame(data[1:], columns=data[0])
        df = self.unmerge_cells(df)
        return df.to_markdown(index=False)

    def unmerge_cells(self, df):
        """
        Fill the empty cells left behind by merged table cells.

        Args:
            df: Table data in which missing cells are empty (None/NaN).
        Returns:
            A DataFrame where gaps are filled from neighbouring cells,
            vertically first and then horizontally.
        """
        # TODO: handle edge cases of combined horizontal/vertical merges
        # Forward fill (then backward fill) to handle vertical merges
        df = df.ffill(axis=0).bfill(axis=0)
        # Forward fill (then backward fill) to handle horizontal merges
        df = df.ffill(axis=1).bfill(axis=1)
        return df