diff --git a/docling_core/transforms/serializer/azure.py b/docling_core/transforms/serializer/azure.py index 674f90b8..d4522a62 100644 --- a/docling_core/transforms/serializer/azure.py +++ b/docling_core/transforms/serializer/azure.py @@ -44,9 +44,10 @@ DocSerializer, create_ser_result, ) -from docling_core.types.doc.base import CoordOrigin -from docling_core.types.doc.document import ( +from docling_core.types.doc import ( + CoordOrigin, DocItem, + DocItemLabel, DoclingDocument, FormItem, InlineGroup, @@ -54,12 +55,12 @@ ListGroup, NodeItem, PictureItem, + ProvenanceItem, RefItem, RichTableCell, TableItem, TextItem, ) -from docling_core.types.doc.labels import DocItemLabel def _bbox_to_polygon_coords( @@ -78,7 +79,7 @@ def _bbox_to_polygon_for_item( doc: DoclingDocument, item: DocItem ) -> Optional[list[float]]: """Compute a TOPLEFT-origin polygon for the first provenance of the item.""" - if not item.prov: + if not item.prov or not isinstance(item.prov[0], ProvenanceItem): return None prov = item.prov[0] @@ -189,7 +190,7 @@ def serialize( # Lists may be represented either as TextItem(ListItem) or via groups; # we treat any TextItem as a paragraph-like entry. - if item.prov: + if item.prov and isinstance(item.prov[0], ProvenanceItem): prov = item.prov[0] page_no = prov.page_no polygon = _bbox_to_polygon_for_item(doc, item) @@ -241,7 +242,7 @@ def serialize( ) -> SerializationResult: assert isinstance(doc_serializer, AzureDocSerializer) - if not item.prov: + if not item.prov or not isinstance(item.prov[0], ProvenanceItem): return create_ser_result() prov = item.prov[0] @@ -322,7 +323,7 @@ def serialize( ) -> SerializationResult: assert isinstance(doc_serializer, AzureDocSerializer) - if not item.prov: + if not item.prov or not isinstance(item.prov[0], ProvenanceItem): return create_ser_result() prov = item.prov[0] @@ -340,7 +341,11 @@ def serialize( for foot_ref in item.footnotes: if isinstance(foot_ref, RefItem): tgt = foot_ref.resolve(doc) - if isinstance(tgt, TextItem) and tgt.prov: + if ( + isinstance(tgt, TextItem) + and tgt.prov + and isinstance(tgt.prov[0], ProvenanceItem) + ): f_poly = _bbox_to_polygon_for_item(doc, tgt) if f_poly is not None: foots.append( diff --git a/docling_core/transforms/serializer/common.py b/docling_core/transforms/serializer/common.py index b494eb0e..abec65e9 100644 --- a/docling_core/transforms/serializer/common.py +++ b/docling_core/transforms/serializer/common.py @@ -34,11 +34,11 @@ SerializationResult, Span, ) -from docling_core.types.doc.document import ( - DOCUMENT_TOKENS_EXPORT_LABELS, +from docling_core.types.doc import ( ContentLayer, DescriptionAnnotation, DocItem, + DocItemLabel, DoclingDocument, FloatingItem, Formatting, @@ -51,12 +51,13 @@ PictureDataType, PictureItem, PictureMoleculeData, + ProvenanceItem, Script, TableAnnotationType, TableItem, TextItem, ) -from docling_core.types.doc.labels import DocItemLabel +from docling_core.types.doc.document import DOCUMENT_TOKENS_EXPORT_LABELS _DEFAULT_LABELS = DOCUMENT_TOKENS_EXPORT_LABELS _DEFAULT_LAYERS = {cl for cl in ContentLayer} @@ -110,7 +111,11 @@ def _iterate_items( add_page_breaks=add_page_breaks, visited=my_visited, ): - if isinstance(it, DocItem) and it.prov: + if ( + isinstance(it, DocItem) + and it.prov + and isinstance(it.prov[0], ProvenanceItem) + ): page_no = it.prov[0].page_no if prev_page_nr is not None and page_no > prev_page_nr: yield _PageBreakNode( @@ -119,7 +124,11 @@ def _iterate_items( next_page=page_no, ), lvl break - elif isinstance(item, DocItem) and item.prov: + elif ( + isinstance(item, DocItem) + and item.prov + and isinstance(item.prov[0], ProvenanceItem) + ): page_no = item.prov[0].page_no if prev_page_nr is None or page_no > prev_page_nr: if prev_page_nr is not None: # close previous range @@ -288,7 +297,10 @@ def get_excluded_refs(self, **kwargs: Any) -> set[str]: params.pages is not None and ( (not item.prov) - or item.prov[0].page_no not in params.pages + or ( + isinstance(item.prov[0], ProvenanceItem) + and item.prov[0].page_no not in params.pages + ) ) ) ) @@ -639,6 +651,7 @@ def _get_applicable_pages(self) -> Optional[list[int]]: if ( isinstance(item, DocItem) and item.prov + and isinstance(item.prov[0], ProvenanceItem) and ( self.params.pages is None or item.prov[0].page_no in self.params.pages diff --git a/docling_core/transforms/serializer/doctags.py b/docling_core/transforms/serializer/doctags.py index beff6168..ca19efc3 100644 --- a/docling_core/transforms/serializer/doctags.py +++ b/docling_core/transforms/serializer/doctags.py @@ -26,11 +26,13 @@ _should_use_legacy_annotations, create_ser_result, ) -from docling_core.types.doc.base import BoundingBox from docling_core.types.doc.document import ( + BoundingBox, CodeItem, DocItem, + DocItemLabel, DoclingDocument, + DocumentToken, FloatingItem, FormItem, GroupItem, @@ -40,6 +42,7 @@ ListItem, NodeItem, PictureClassificationData, + PictureClassificationLabel, PictureItem, PictureMoleculeData, PictureTabularChartData, @@ -47,10 +50,9 @@ SectionHeaderItem, TableData, TableItem, + TableToken, TextItem, ) -from docling_core.types.doc.labels import DocItemLabel, PictureClassificationLabel -from docling_core.types.doc.tokens import DocumentToken, TableToken def _wrap(text: str, wrap_tag: str) -> str: @@ -365,7 +367,7 @@ def serialize( results: list[SerializationResult] = [] page_no = 1 - if len(item.prov) > 0: + if len(item.prov) > 0 and isinstance(item.prov[0], ProvenanceItem): page_no = item.prov[0].page_no if params.add_location: @@ -385,7 +387,7 @@ def serialize( for cell in item.graph.cells: cell_txt = "" - if cell.prov is not None: + if cell.prov is not None and isinstance(cell.prov, ProvenanceItem): if len(doc.pages.keys()): page_w, page_h = doc.pages[page_no].size.as_tuple() cell_txt += DocumentToken.get_location( @@ -498,7 +500,7 @@ def _get_inline_location_tags( doc_items: list[DocItem] = [] for it, _ in doc.iterate_items(root=item): if isinstance(it, DocItem): - for prov in it.prov: + for prov in (im for im in it.prov if isinstance(im, ProvenanceItem)): boxes.append(prov.bbox) doc_items.append(it) if prov is None: diff --git a/docling_core/transforms/visualizer/key_value_visualizer.py b/docling_core/transforms/visualizer/key_value_visualizer.py index b0198455..1ef12654 100644 --- a/docling_core/transforms/visualizer/key_value_visualizer.py +++ b/docling_core/transforms/visualizer/key_value_visualizer.py @@ -16,8 +16,13 @@ from typing_extensions import override from docling_core.transforms.visualizer.base import BaseVisualizer -from docling_core.types.doc.document import ContentLayer, DoclingDocument -from docling_core.types.doc.labels import GraphCellLabel, GraphLinkLabel +from docling_core.types.doc import ( + ContentLayer, + DoclingDocument, + GraphCellLabel, + GraphLinkLabel, + ProvenanceItem, +) # --------------------------------------------------------------------------- # Helper functions / constants @@ -78,7 +83,11 @@ def _draw_key_value_layer( # First draw cells (rectangles + optional labels) # ------------------------------------------------------------------ for cell in cell_dict.values(): - if cell.prov is None or cell.prov.page_no != page_no: + if ( + cell.prov is None + or not isinstance(cell.prov, ProvenanceItem) + or cell.prov.page_no != page_no + ): continue # skip cells not on this page or without bbox tl_bbox = cell.prov.bbox.to_top_left_origin( @@ -127,6 +136,8 @@ def _draw_key_value_layer( if ( src_cell.prov is None or tgt_cell.prov is None + or not isinstance(src_cell.prov, ProvenanceItem) + or not isinstance(tgt_cell.prov, ProvenanceItem) or src_cell.prov.page_no != page_no or tgt_cell.prov.page_no != page_no ): diff --git a/docling_core/transforms/visualizer/layout_visualizer.py b/docling_core/transforms/visualizer/layout_visualizer.py index 886ad8b4..8478a198 100644 --- a/docling_core/transforms/visualizer/layout_visualizer.py +++ b/docling_core/transforms/visualizer/layout_visualizer.py @@ -10,10 +10,16 @@ from typing_extensions import override from docling_core.transforms.visualizer.base import BaseVisualizer -from docling_core.types.doc import DocItemLabel -from docling_core.types.doc.base import CoordOrigin -from docling_core.types.doc.document import ContentLayer, DocItem, DoclingDocument -from docling_core.types.doc.page import BoundingRectangle, TextCell +from docling_core.types.doc import ( + BoundingRectangle, + ContentLayer, + CoordOrigin, + DocItem, + DocItemLabel, + DoclingDocument, + ProvenanceItem, + TextCell, +) class _TLBoundingRectangle(BoundingRectangle): @@ -157,7 +163,9 @@ def _draw_doc_layout( if len(elem.prov) == 0: continue # Skip elements without provenances - for prov in elem.prov: + for prov in ( + item for item in elem.prov if isinstance(item, ProvenanceItem) + ): page_nr = prov.page_no if page_nr in my_images: diff --git a/docling_core/transforms/visualizer/reading_order_visualizer.py b/docling_core/transforms/visualizer/reading_order_visualizer.py index c012f22b..0e2aa6a1 100644 --- a/docling_core/transforms/visualizer/reading_order_visualizer.py +++ b/docling_core/transforms/visualizer/reading_order_visualizer.py @@ -15,6 +15,7 @@ DocItem, DoclingDocument, PictureItem, + ProvenanceItem, ) @@ -139,7 +140,9 @@ def _draw_doc_reading_order( if len(elem.prov) == 0: continue # Skip elements without provenances - for prov in elem.prov: + for prov in ( + item for item in elem.prov if isinstance(item, ProvenanceItem) + ): page_no = prov.page_no image = my_images.get(page_no) diff --git a/docling_core/transforms/visualizer/table_visualizer.py b/docling_core/transforms/visualizer/table_visualizer.py index 0a722959..c173f33f 100644 --- a/docling_core/transforms/visualizer/table_visualizer.py +++ b/docling_core/transforms/visualizer/table_visualizer.py @@ -10,7 +10,12 @@ from typing_extensions import override from docling_core.transforms.visualizer.base import BaseVisualizer -from docling_core.types.doc.document import ContentLayer, DoclingDocument, TableItem +from docling_core.types.doc import ( + ContentLayer, + DoclingDocument, + ProvenanceItem, + TableItem, +) _log = logging.getLogger(__name__) @@ -171,12 +176,12 @@ def _draw_doc_tables( image = deepcopy(pil_img) my_images[page_nr] = image - for idx, (elem, _) in enumerate( + for _, (elem, _) in enumerate( doc.iterate_items(included_content_layers=included_content_layers) ): if not isinstance(elem, TableItem): continue - if len(elem.prov) == 0: + if len(elem.prov) == 0 or not isinstance(elem.prov[0], ProvenanceItem): continue # Skip elements without provenances if len(elem.prov) == 1: diff --git a/docling_core/types/doc/__init__.py b/docling_core/types/doc/__init__.py index 3c699f89..d8ddd0b4 100644 --- a/docling_core/types/doc/__init__.py +++ b/docling_core/types/doc/__init__.py @@ -56,11 +56,13 @@ PictureStackedBarChartData, PictureTabularChartData, ProvenanceItem, + ProvenanceTrack, RefItem, RichTableCell, Script, SectionHeaderItem, SummaryMetaField, + TableAnnotationType, TableCell, TableData, TableItem, diff --git a/docling_core/types/doc/document.py b/docling_core/types/doc/document.py index 414640d6..2e3c2630 100644 --- a/docling_core/types/doc/document.py +++ b/docling_core/types/doc/document.py @@ -35,9 +35,11 @@ AnyUrl, BaseModel, ConfigDict, + Discriminator, Field, FieldSerializationInfo, StringConstraints, + Tag, computed_field, field_serializer, field_validator, @@ -1206,11 +1208,92 @@ def from_multipage_doctags_and_images( class ProvenanceItem(BaseModel): - """ProvenanceItem.""" + """Provenance information for elements extracted from a textual document. + + A `ProvenanceItem` object acts as a lightweight pointer back into the original + document for an extracted element. It applies to documents with an explicity + or implicit layout, such as PDF, HTML, docx, or pptx. + """ + + page_no: Annotated[int, Field(description="Page number")] + bbox: Annotated[BoundingBox, Field(description="Bounding box")] + charspan: Annotated[ + tuple[int, int], Field(description="Character span (0-indexed)") + ] - page_no: int - bbox: BoundingBox - charspan: Tuple[int, int] + +class ProvenanceTrack(BaseModel): + """Provenance information for elements extracted from media assets. + + A `ProvenanceTrack` instance describes a cue in a text track associated with a + media element (audio, video, subtitles, screen recordings, ...). + """ + + start_time: Annotated[ + float, + Field( + examples=[11.0, 6.5, 5370.0], + description="Start time offset of the track cue in seconds", + ), + ] + end_time: Annotated[ + float, + Field( + examples=[12.0, 8.2, 5370.1], + description="End time offset of the track cue in seconds", + ), + ] + identifier: Optional[str] = Field( + None, + examples=["test", "123", "b72d946"], + description="An identifier of the cue", + ) + voice: Optional[str] = Field( + None, + examples=["Mary", "Fred", "Name Surname"], + description="The cue voice (speaker)", + ) + languages: Optional[list[str]] = Field( + None, + examples=[["en", "en-GB"], ["fr-CA"]], + description="Languages of the cue in BCP 47 language tag format", + ) + classes: Optional[list[str]] = Field( + None, + min_length=1, + examples=["first", "loud", "yellow"], + description="Classes for describing the cue significance", + ) + + @model_validator(mode="after") + def check_order(self) -> Self: + """Ensure start time is less than the end time.""" + if self.end_time <= self.start_time: + raise ValueError("End time must be greater than start time") + return self + + +def get_provenance_discriminator_value(v: Any) -> str: + """Callable discriminator for provenance instances. + + Args: + v: Either dict or model input. + + Returns: + A string discriminator of provenance instances. + """ + fields = {"bbox", "page_no", "charspan"} + if isinstance(v, dict): + return "item" if any(f in v for f in fields) else "track" + return "item" if any(hasattr(v, f) for f in fields) else "track" + + +ProvenanceType = Annotated[ + Union[ + Annotated[ProvenanceItem, Tag("item")], Annotated[ProvenanceTrack, Tag("track")] + ], + Discriminator(get_provenance_discriminator_value), +] class ContentLayer(str, Enum): @@ -1534,7 +1617,7 @@ class DocItem( """DocItem.""" label: DocItemLabel - prov: List[ProvenanceItem] = [] + prov: List[ProvenanceType] = [] def get_location_tokens( self, @@ -1549,7 +1632,7 @@ def get_location_tokens( return "" location = "" - for prov in self.prov: + for prov in (item for item in self.prov if isinstance(item, ProvenanceItem)): page_w, page_h = doc.pages[prov.page_no].size.as_tuple() loc_str = DocumentToken.get_location( @@ -1573,10 +1656,13 @@ def get_image( if a valid image of the page containing this DocItem is not available in doc. """ - if not len(self.prov): + if not self.prov or prov_index >= len(self.prov): + return None + prov = self.prov[prov_index] + if not isinstance(prov, ProvenanceItem): return None - page = doc.pages.get(self.prov[prov_index].page_no) + page = doc.pages.get(prov.page_no) if page is None or page.size is None or page.image is None: return None @@ -1584,9 +1670,9 @@ def get_image( if not page_image: return None crop_bbox = ( - self.prov[prov_index] - .bbox.to_top_left_origin(page_height=page.size.height) - .scale_to_size(old_size=page.size, new_size=page.image.size) + prov.bbox.to_top_left_origin(page_height=page.size.height).scale_to_size( + old_size=page.size, new_size=page.image.size + ) # .scaled(scale=page_image.height / page.size.height) ) return page_image.crop(crop_bbox.as_tuple()) @@ -2284,7 +2370,7 @@ def export_to_otsl( return "" page_no = 0 - if len(self.prov) > 0: + if len(self.prov) > 0 and isinstance(self.prov[0], ProvenanceItem): page_no = self.prov[0].page_no for i in range(nrows): @@ -2416,7 +2502,7 @@ class GraphCell(BaseModel): text: str # sanitized text orig: str # text as seen on document - prov: Optional[ProvenanceItem] = None + prov: Optional[ProvenanceType] = None # in case you have a text, table or picture item item_ref: Optional[RefItem] = None @@ -3107,7 +3193,7 @@ def add_list_item( enumerated: bool = False, marker: Optional[str] = None, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3118,7 +3204,7 @@ def add_list_item( :param label: str: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ @@ -3159,7 +3245,7 @@ def add_text( label: DocItemLabel, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3170,7 +3256,7 @@ def add_text( :param label: str: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ @@ -3265,7 +3351,7 @@ def add_table( self, data: TableData, caption: Optional[Union[TextItem, RefItem]] = None, # This is not cool yet. - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, label: DocItemLabel = DocItemLabel.TABLE, content_layer: Optional[ContentLayer] = None, @@ -3275,7 +3361,7 @@ def add_table( :param data: TableData: :param caption: Optional[Union[TextItem, RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) :param label: DocItemLabel: (Default value = DocItemLabel.TABLE) @@ -3311,7 +3397,7 @@ def add_picture( annotations: Optional[List[PictureDataType]] = None, image: Optional[ImageRef] = None, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, ): @@ -3320,7 +3406,7 @@ def add_picture( :param data: Optional[List[PictureData]]: (Default value = None) :param caption: Optional[Union[TextItem: :param RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3352,7 +3438,7 @@ def add_title( self, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3363,7 +3449,7 @@ def add_title( :param text: str: :param orig: Optional[str]: (Default value = None) :param level: LevelNumber: (Default value = 1) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3398,7 +3484,7 @@ def add_code( code_language: Optional[CodeLanguageLabel] = None, orig: Optional[str] = None, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3411,7 +3497,7 @@ def add_code( :param orig: Optional[str]: (Default value = None) :param caption: Optional[Union[TextItem: :param RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3448,7 +3534,7 @@ def add_formula( self, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3459,7 +3545,7 @@ def add_formula( :param text: str: :param orig: Optional[str]: (Default value = None) :param level: LevelNumber: (Default value = 1) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3493,7 +3579,7 @@ def add_heading( text: str, orig: Optional[str] = None, level: LevelNumber = 1, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, @@ -3505,7 +3591,7 @@ def add_heading( :param text: str: :param orig: Optional[str]: (Default value = None) :param level: LevelNumber: (Default value = 1) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3538,13 +3624,13 @@ def add_heading( def add_key_values( self, graph: GraphData, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, ): """add_key_values. :param graph: GraphData: - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3569,13 +3655,13 @@ def add_key_values( def add_form( self, graph: GraphData, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, parent: Optional[NodeItem] = None, ): """add_form. :param graph: GraphData: - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param parent: Optional[NodeItem]: (Default value = None) """ if not parent: @@ -3772,7 +3858,7 @@ def insert_list_item( enumerated: bool = False, marker: Optional[str] = None, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -3785,7 +3871,7 @@ def insert_list_item( :param enumerated: bool: (Default value = False) :param marker: Optional[str]: (Default value = None) :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -3846,7 +3932,7 @@ def insert_text( label: DocItemLabel, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -3858,7 +3944,7 @@ def insert_text( :param label: DocItemLabel: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -3958,7 +4044,7 @@ def insert_table( sibling: NodeItem, data: TableData, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, label: DocItemLabel = DocItemLabel.TABLE, content_layer: Optional[ContentLayer] = None, annotations: Optional[list[TableAnnotationType]] = None, @@ -3969,7 +4055,7 @@ def insert_table( :param sibling: NodeItem: :param data: TableData: :param caption: Optional[Union[TextItem, RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param label: DocItemLabel: (Default value = DocItemLabel.TABLE) :param content_layer: Optional[ContentLayer]: (Default value = None) :param annotations: Optional[List[TableAnnotationType]]: (Default value = None) @@ -4006,7 +4092,7 @@ def insert_picture( annotations: Optional[List[PictureDataType]] = None, image: Optional[ImageRef] = None, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, after: bool = True, ) -> PictureItem: @@ -4016,7 +4102,7 @@ def insert_picture( :param annotations: Optional[List[PictureDataType]]: (Default value = None) :param image: Optional[ImageRef]: (Default value = None) :param caption: Optional[Union[TextItem, RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param after: bool: (Default value = True) @@ -4050,7 +4136,7 @@ def insert_title( sibling: NodeItem, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -4061,7 +4147,7 @@ def insert_title( :param sibling: NodeItem: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -4101,7 +4187,7 @@ def insert_code( code_language: Optional[CodeLanguageLabel] = None, orig: Optional[str] = None, caption: Optional[Union[TextItem, RefItem]] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -4114,7 +4200,7 @@ def insert_code( :param code_language: Optional[str]: (Default value = None) :param orig: Optional[str]: (Default value = None) :param caption: Optional[Union[TextItem, RefItem]]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -4156,7 +4242,7 @@ def insert_formula( sibling: NodeItem, text: str, orig: Optional[str] = None, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -4167,7 +4253,7 @@ def insert_formula( :param sibling: NodeItem: :param text: str: :param orig: Optional[str]: (Default value = None) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -4206,7 +4292,7 @@ def insert_heading( text: str, orig: Optional[str] = None, level: LevelNumber = 1, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, content_layer: Optional[ContentLayer] = None, formatting: Optional[Formatting] = None, hyperlink: Optional[Union[AnyUrl, Path]] = None, @@ -4218,7 +4304,7 @@ def insert_heading( :param text: str: :param orig: Optional[str]: (Default value = None) :param level: LevelNumber: (Default value = 1) - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param content_layer: Optional[ContentLayer]: (Default value = None) :param formatting: Optional[Formatting]: (Default value = None) :param hyperlink: Optional[Union[AnyUrl, Path]]: (Default value = None) @@ -4256,14 +4342,14 @@ def insert_key_values( self, sibling: NodeItem, graph: GraphData, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, after: bool = True, ) -> KeyValueItem: """Creates a new KeyValueItem item and inserts it into the document. :param sibling: NodeItem: :param graph: GraphData: - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param after: bool: (Default value = True) :returns: KeyValueItem: The newly created KeyValueItem item. @@ -4285,14 +4371,14 @@ def insert_form( self, sibling: NodeItem, graph: GraphData, - prov: Optional[ProvenanceItem] = None, + prov: Optional[ProvenanceType] = None, after: bool = True, ) -> FormItem: """Creates a new FormItem item and inserts it into the document. :param sibling: NodeItem: :param graph: GraphData: - :param prov: Optional[ProvenanceItem]: (Default value = None) + :param prov: Optional[ProvenanceType]: (Default value = None) :param after: bool: (Default value = True) :returns: FormItem: The newly created FormItem item. @@ -4660,7 +4746,11 @@ def _iterate_items_with_stack( not isinstance(root, DocItem) or ( page_nrs is None - or any(prov.page_no in page_nrs for prov in root.prov) + or any( + prov.page_no in page_nrs + for prov in root.prov + if isinstance(prov, ProvenanceItem) + ) ) ) and root.content_layer in my_layers @@ -4770,7 +4860,7 @@ def _with_pictures_refs( image_dir.mkdir(parents=True, exist_ok=True) if image_dir.is_dir(): - for item, level in result.iterate_items(page_no=page_no, with_groups=False): + for item, _ in result.iterate_items(page_no=page_no, with_groups=False): if isinstance(item, PictureItem): img = item.get_image(doc=self) if img is not None: @@ -4790,12 +4880,15 @@ def _with_pictures_refs( else: obj_path = loc_path - if item.image is None: + if item.image is None and isinstance( + item.prov[0], ProvenanceItem + ): scale = img.size[0] / item.prov[0].bbox.width item.image = ImageRef.from_pil( image=img, dpi=round(72 * scale) ) - item.image.uri = Path(obj_path) + elif item.image is not None: + item.image.uri = Path(obj_path) # if item.image._pil is not None: # item.image._pil.close() @@ -6274,7 +6367,11 @@ def index( if isinstance(new_item, DocItem): # update page numbers # NOTE other prov sources (e.g. GraphCell) currently not covered - for prov in new_item.prov: + for prov in ( + item + for item in new_item.prov + if isinstance(item, ProvenanceItem) + ): prov.page_no += page_delta if item.parent: diff --git a/docling_core/types/doc/webvtt.py b/docling_core/types/doc/webvtt.py new file mode 100644 index 00000000..ea202284 --- /dev/null +++ b/docling_core/types/doc/webvtt.py @@ -0,0 +1,599 @@ +"""Models for the Docling's adoption of Web Video Text Tracks format.""" + +import logging +import re +from enum import Enum +from typing import Annotated, ClassVar, Iterator, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic.types import StringConstraints +from typing_extensions import Self, override + +_log = logging.getLogger(__name__) + + +_VALID_ENTITIES: set = {"amp", "lt", "gt", "lrm", "rlm", "nbsp"} +_ENTITY_PATTERN: re.Pattern = re.compile(r"&([a-zA-Z0-9]+);") +_START_TAG_NAMES = Literal["c", "b", "i", "u", "v", "lang"] + + +class WebVTTLineTerminator(str, Enum): + """WebVTT line terminator.""" + + CRLF = "\r\n" + LF = "\n" + CR = "\r" + + +WebVTTCueIdentifier = Annotated[ + str, StringConstraints(strict=True, pattern=r"^(?!.*-->)[^\n\r]+$") +] + + +class WebVTTTimestamp(BaseModel): + """WebVTT timestamp. + + The timestamp is a string consisting of the following components in the given order: + + - hours (optional, required if non-zero): two or more digits + - minutes: two digits between 0 and 59 + - a colon character (:) + - seconds: two digits between 0 and 59 + - a full stop character (.) + - thousandths of a second: three digits + + A WebVTT timestamp is always interpreted relative to the current playback position + of the media data that the WebVTT file is to be synchronized with. + """ + + model_config = ConfigDict(regex_engine="python-re") + + raw: Annotated[ + str, + Field( + description="A representation of the WebVTT Timestamp as a single string" + ), + ] + + _pattern: ClassVar[re.Pattern] = re.compile( + r"^(?:(\d{2,}):)?([0-5]\d):([0-5]\d)\.(\d{3})$" + ) + _hours: int + _minutes: int + _seconds: int + _millis: int + + @model_validator(mode="after") + def validate_raw(self) -> Self: + """Validate the WebVTT timestamp as a string.""" + m = self._pattern.match(self.raw) + if not m: + raise ValueError(f"Invalid WebVTT timestamp format: {self.raw}") + self._hours = int(m.group(1)) if m.group(1) else 0 + self._minutes = int(m.group(2)) + self._seconds = int(m.group(3)) + self._millis = int(m.group(4)) + + if self._minutes < 0 or self._minutes > 59: + raise ValueError("Minutes must be between 0 and 59") + if self._seconds < 0 or self._seconds > 59: + raise ValueError("Seconds must be between 0 and 59") + + return self + + @property + def seconds(self) -> float: + """A representation of the WebVTT Timestamp in seconds.""" + return ( + self._hours * 3600 + + self._minutes * 60 + + self._seconds + + self._millis / 1000.0 + ) + + @override + def __str__(self) -> str: + """Return a string representation of a WebVTT timestamp.""" + return self.raw + + +class WebVTTCueTimings(BaseModel): + """WebVTT cue timings.""" + + start: Annotated[WebVTTTimestamp, Field(description="Start time offset of the cue")] + end: Annotated[WebVTTTimestamp, Field(description="End time offset of the cue")] + + @model_validator(mode="after") + def check_order(self) -> Self: + """Ensure start timestamp is less than end timestamp.""" + if self.start and self.end: + if self.end.seconds <= self.start.seconds: + raise ValueError("End timestamp must be greater than start timestamp") + return self + + @override + def __str__(self): + """Return a string representation of the cue timings.""" + return f"{self.start} --> {self.end}" + + +class WebVTTCueTextSpan(BaseModel): + """WebVTT cue text span.""" + + kind: Literal["text"] = "text" + text: Annotated[str, Field(description="The cue text.")] + + @field_validator("text", mode="after") + @classmethod + def is_valid_text(cls, value: str) -> str: + """Ensure cue text contains only permitted characters and HTML entities.""" + for match in _ENTITY_PATTERN.finditer(value): + entity = match.group(1) + if entity not in _VALID_ENTITIES: + raise ValueError( + f"Cue text contains an invalid HTML entity: &{entity};" + ) + if "&" in re.sub(_ENTITY_PATTERN, "", value): + raise ValueError("Found '&' not part of a valid entity in the cue text") + if any(ch in value for ch in {"\n", "\r", "<"}): + raise ValueError("Cue text contains invalid characters") + if len(value) == 0: + raise ValueError("Cue text cannot be empty") + + return value + + @override + def __str__(self): + """Return a string representation of the cue text span.""" + return self.text + + +class WebVTTCueComponentWithTerminator(BaseModel): + """WebVTT caption or subtitle cue component optionally with a line terminator.""" + + component: "WebVTTCueComponent" + terminator: Optional[WebVTTLineTerminator] = None + + @override + def __str__(self): + """Return a string representation of the cue component with terminator.""" + return f"{self.component}{self.terminator.value if self.terminator else ''}" + + +class WebVTTCueInternalText(BaseModel): + """WebVTT cue internal text.""" + + terminator: Optional[WebVTTLineTerminator] = None + components: Annotated[ + list[WebVTTCueComponentWithTerminator], + Field( + description=( + "WebVTT caption or subtitle cue components representing the " + "cue internal text" + ) + ), + ] = [] + + @override + def __str__(self): + """Return a string representation of the cue internal text.""" + cue_str = ( + f"{self.terminator.value if self.terminator else ''}" + f"{''.join(str(span) for span in self.components)}" + ) + return cue_str + + +class WebVTTCueSpanStartTag(BaseModel): + """WebVTT cue span start tag.""" + + name: Annotated[_START_TAG_NAMES, Field(description="The tag name")] + classes: Annotated[ + list[str], + Field(description="List of classes representing the cue span's significance"), + ] = [] + + @field_validator("classes", mode="after") + @classmethod + def validate_classes(cls, value: list[str]) -> list[str]: + """Validate cue span start tag classes.""" + for item in value: + if any(ch in item for ch in {"\t", "\n", "\r", " ", "&", "<", ">", "."}): + raise ValueError( + "A cue span start tag class contains invalid characters" + ) + if not item: + raise ValueError("Cue span start tag classes cannot be empty") + return value + + def _get_name_with_classes(self) -> str: + """Return the name of the cue span start tag with classes.""" + return f"{self.name}.{'.'.join(self.classes)}" if self.classes else self.name + + @override + def __str__(self): + """Return a string representation of the cue span start tag.""" + return f"<{self._get_name_with_classes()}>" + + +class WebVTTCueSpanStartTagAnnotated(WebVTTCueSpanStartTag): + """WebVTT cue span start tag requiring an annotation.""" + + annotation: Annotated[str, Field(description="Cue span start tag annotation")] + + @field_validator("annotation", mode="after") + @classmethod + def is_valid_annotation(cls, value: str) -> str: + """Ensure annotation contains only permitted characters and HTML entities.""" + for match in _ENTITY_PATTERN.finditer(value): + entity = match.group(1) + if entity not in _VALID_ENTITIES: + raise ValueError( + f"Annotation contains an invalid HTML entity: &{entity};" + ) + if "&" in re.sub(_ENTITY_PATTERN, "", value): + raise ValueError("Found '&' not part of a valid entity in annotation") + if any(ch in value for ch in {"\n", "\r", ">"}): + raise ValueError("Annotation contains invalid characters") + if len(value) == 0: + raise ValueError("Annotation cannot be empty") + + return value + + @override + def __str__(self): + """Return a string representation of the cue span start tag.""" + return f"<{self._get_name_with_classes()} {self.annotation}>" + + +class WebVTTCueLanguageSpanStartTag(WebVTTCueSpanStartTagAnnotated): + """WebVTT cue language span start tag.""" + + _pattern: ClassVar[re.Pattern] = re.compile( + r"^[a-zA-Z]{2,3}(-[a-zA-Z0-9]{2,8})*$", re.IGNORECASE + ) + + name: Literal["lang"] = Field("lang", description="The tag name") + + @field_validator("annotation", mode="after") + @classmethod + @override + def is_valid_annotation(cls, value: str) -> str: + """Ensure that the language annotation is in BCP 47 language tag format.""" + if cls._pattern.match(value): + return value + else: + raise ValueError("Annotation should be in BCP 47 language tag format") + + +class WebVTTCueComponentBase(BaseModel): + """WebVTT caption or subtitle cue component. + + All the WebVTT caption or subtitle cue components are represented by this class + except the WebVTT cue text span, which requires different definitions. + """ + + kind: Literal["c", "b", "i", "u", "v", "lang"] + start_tag: WebVTTCueSpanStartTag + internal_text: WebVTTCueInternalText + + @model_validator(mode="after") + def check_tag_names_match(self) -> Self: + """Ensure that the start tag name matches this cue component type.""" + if self.kind != self.start_tag.name: + raise ValueError("The tag name of this cue component should be {self.kind}") + return self + + @override + def __str__(self): + """Return a string representation of the cue component.""" + return f"{self.start_tag}{self.internal_text}" + + +class WebVTTCueVoiceSpan(WebVTTCueComponentBase): + """WebVTT cue voice span associated with a specific voice.""" + + kind: Literal["v"] = "v" + start_tag: WebVTTCueSpanStartTagAnnotated + + +class WebVTTCueClassSpan(WebVTTCueComponentBase): + """WebVTT cue class span. + + It represents a span of text and it is used to annotate parts of the cue with + applicable classes without implying further meaning (such as italics or bold). + """ + + kind: Literal["c"] = "c" + start_tag: WebVTTCueSpanStartTag = WebVTTCueSpanStartTag(name="c") + + +class WebVTTCueItalicSpan(WebVTTCueComponentBase): + """WebVTT cue italic span representing a span of italic text.""" + + kind: Literal["i"] = "i" + start_tag: WebVTTCueSpanStartTag = WebVTTCueSpanStartTag(name="i") + + +class WebVTTCueBoldSpan(WebVTTCueComponentBase): + """WebVTT cue bold span representing a span of bold text.""" + + kind: Literal["b"] = "b" + start_tag: WebVTTCueSpanStartTag = WebVTTCueSpanStartTag(name="b") + + +class WebVTTCueUnderlineSpan(WebVTTCueComponentBase): + """WebVTT cue underline span representing a span of underline text.""" + + kind: Literal["u"] = "u" + start_tag: WebVTTCueSpanStartTag = WebVTTCueSpanStartTag(name="u") + + +class WebVTTCueLanguageSpan(WebVTTCueComponentBase): + """WebVTT cue language span. + + It represents a span of text and it is used to annotate parts of the cue where the + applicable language might be different than the surrounding text's, without + implying further meaning (such as italics or bold). + """ + + kind: Literal["lang"] = "lang" + start_tag: WebVTTCueLanguageSpanStartTag + + +WebVTTCueComponent = Annotated[ + Union[ + WebVTTCueTextSpan, + WebVTTCueClassSpan, + WebVTTCueItalicSpan, + WebVTTCueBoldSpan, + WebVTTCueUnderlineSpan, + WebVTTCueVoiceSpan, + WebVTTCueLanguageSpan, + ], + Field( + discriminator="kind", + description="The type of WebVTT caption or subtitle cue component.", + ), +] + + +class WebVTTCueBlock(BaseModel): + """Model representing a WebVTT cue block. + + The optional WebVTT cue settings list is not supported. + The cue payload is limited to the following spans: text, class, italic, bold, + underline, and voice. + """ + + model_config = ConfigDict(regex_engine="python-re") + + identifier: Optional[WebVTTCueIdentifier] = Field( + None, description="The WebVTT cue identifier" + ) + timings: Annotated[WebVTTCueTimings, Field(description="The WebVTT cue timings")] + payload: Annotated[ + list[WebVTTCueComponentWithTerminator], + Field(description="The WebVTT caption or subtitle cue text"), + ] + + # pattern of a WebVTT cue span start/end tag + _pattern_tag: ClassVar[re.Pattern] = re.compile( + r"<(?P/?)" + r"(?Pi|b|c|u|v|lang)" + r"(?P(?:\.[^\t\n\r &<>.]+)*)" + r"(?:[ \t](?P[^\n\r&>]*))?>" + ) + + @field_validator("payload", mode="after") + @classmethod + def validate_payload(cls, payload): + """Ensure that the cue payload contains valid text.""" + for voice in payload: + if "-->" in str(voice): + raise ValueError("Cue payload must not contain '-->'") + return payload + + @staticmethod + def _create_text_components( + text: str, + ) -> Iterator[WebVTTCueComponentWithTerminator]: + text_list = text.split("\n") + for idx, line in enumerate(text.split("\n")): + terminator = ( + WebVTTLineTerminator.LF + if idx < len(text_list) - 1 or text.endswith("\n") + else None + ) + if len(line) > 0: + yield WebVTTCueComponentWithTerminator( + component=WebVTTCueTextSpan(text=line), + terminator=terminator, + ) + + @classmethod + def parse(cls, raw: str) -> "WebVTTCueBlock": + """Parse a WebVTT cue block from a string. + + Args: + raw: The raw WebVTT cue block string. + + Returns: + The parsed WebVTT cue block. + """ + lines = raw.strip().splitlines() + if not lines: + raise ValueError("Cue block must have at least one line") + identifier: Optional[WebVTTCueIdentifier] = None + timing_line = lines[0] + if "-->" not in timing_line and len(lines) > 1: + identifier = timing_line + timing_line = lines[1] + cue_lines = lines[2:] + else: + cue_lines = lines[1:] + + if "-->" not in timing_line: + raise ValueError("Cue block must contain WebVTT cue timings") + + start, end = [t.strip() for t in timing_line.split("-->")] + end = re.split(" |\t", end)[0] # ignore the cue settings list + timings: WebVTTCueTimings = WebVTTCueTimings( + start=WebVTTTimestamp(raw=start), end=WebVTTTimestamp(raw=end) + ) + cue_text = "\n".join(cue_lines).strip() + # adding close tag for cue spans without end tag + for omm in {"v"}: + if cue_text.startswith(f"<{omm}") and f"" not in cue_text: + cue_text += f"" + break + + stack: list[list[WebVTTCueComponentWithTerminator]] = [[]] + tag_stack: list[dict] = [] + + pos = 0 + matches = list(cls._pattern_tag.finditer(cue_text)) + i = 0 + while i < len(matches): + match = matches[i] + if match.start() > pos: + text = cue_text[pos : match.start()] + stack[-1].extend(cls._create_text_components(text)) + gps = {k: (v if v else None) for k, v in match.groupdict().items()} + + if gps["tag"] in {"c", "b", "i", "u", "v", "lang"}: + if not gps["end"]: + tag_stack.append(gps) + stack.append([]) + else: + children = stack.pop() if stack else [] + if tag_stack: + closed = tag_stack.pop() + if (ct := closed["tag"]) != gps["tag"]: + raise ValueError(f"Incorrect end tag: {ct}") + class_string = closed["class"] + annotation = closed["annotation"] + classes: list[str] = [] + if class_string: + classes = [c for c in class_string.split(".") if c] + st: WebVTTCueSpanStartTag + if annotation and ct == "lang": + st = WebVTTCueLanguageSpanStartTag( + name=ct, classes=classes, annotation=annotation.strip() + ) + elif annotation: + st = WebVTTCueSpanStartTagAnnotated( + name=ct, classes=classes, annotation=annotation.strip() + ) + else: + st = WebVTTCueSpanStartTag(name=ct, classes=classes) + it = WebVTTCueInternalText(components=children) + cp: WebVTTCueComponent + if ct == "c": + cp = WebVTTCueClassSpan(start_tag=st, internal_text=it) + elif ct == "b": + cp = WebVTTCueBoldSpan(start_tag=st, internal_text=it) + elif ct == "i": + cp = WebVTTCueItalicSpan(start_tag=st, internal_text=it) + elif ct == "u": + cp = WebVTTCueUnderlineSpan(start_tag=st, internal_text=it) + elif ct == "lang": + cp = WebVTTCueLanguageSpan(start_tag=st, internal_text=it) + elif ct == "v": + cp = WebVTTCueVoiceSpan(start_tag=st, internal_text=it) + stack[-1].append(WebVTTCueComponentWithTerminator(component=cp)) + + pos = match.end() + i += 1 + + if pos < len(cue_text): + text = cue_text[pos:] + stack[-1].extend(cls._create_text_components(text)) + + return cls( + identifier=identifier, + timings=timings, + payload=stack[0], + ) + + def __str__(self): + """Return a string representation of the WebVTT cue block.""" + parts = [] + if self.identifier: + parts.append(f"{self.identifier}\n") + timings_line = str(self.timings) + parts.append(timings_line + "\n") + for idx, span in enumerate(self.payload): + if idx == 0 and len(self.payload) == 1 and span.component.kind == "v": + # the end tag may be omitted for brevity + parts.append(str(span).removesuffix("")) + else: + parts.append(str(span)) + + return "".join(parts) + "\n" + + +class WebVTTFile(BaseModel): + """A model representing a WebVTT file.""" + + cue_blocks: list[WebVTTCueBlock] + + @staticmethod + def verify_signature(content: str) -> bool: + """Verify the WebVTT file signature.""" + if not content: + return False + elif len(content) == 6: + return content == "WEBVTT" + elif len(content) > 6 and content.startswith("WEBVTT"): + return content[6] in (" ", "\t", "\n") + else: + return False + + @classmethod + def parse(cls, raw: str) -> "WebVTTFile": + """Parse a WebVTT file. + + Args: + raw: The raw WebVTT file content. + + Returns: + The parsed WebVTT file. + """ + # Normalize newlines to LF + raw = raw.replace("\r\n", "\n").replace("\r", "\n") + + # Check WebVTT signature + if not cls.verify_signature(raw): + raise ValueError("Invalid WebVTT file signature") + + # Strip "WEBVTT" header line + lines = raw.split("\n", 1) + body = lines[1] if len(lines) > 1 else "" + + # Remove NOTE/STYLE/REGION blocks + body = re.sub(r"^(NOTE[^\n]*\n(?:.+\n)*?)\n", "", body, flags=re.MULTILINE) + body = re.sub(r"^(STYLE|REGION)(?:.+\n)*?\n", "", body, flags=re.MULTILINE) + + # Split into cue blocks + raw_blocks = re.split(r"\n\s*\n", body.strip()) + cues: list[WebVTTCueBlock] = [] + for block in raw_blocks: + try: + cues.append(WebVTTCueBlock.parse(block)) + except ValueError as e: + _log.warning(f"Failed to parse cue block:\n{block}\n{e}") + + return cls(cue_blocks=cues) + + def __iter__(self): + """Return an iterator over the cue blocks.""" + return iter(self.cue_blocks) + + def __getitem__(self, idx): + """Return the cue block at the given index.""" + return self.cue_blocks[idx] + + def __len__(self): + """Return the number of cue blocks.""" + return len(self.cue_blocks) diff --git a/docling_core/utils/legacy.py b/docling_core/utils/legacy.py index 6f8fdf99..b3b21364 100644 --- a/docling_core/utils/legacy.py +++ b/docling_core/utils/legacy.py @@ -7,20 +7,23 @@ from docling_core.types.doc import ( BoundingBox, + ContentLayer, CoordOrigin, DocItem, DocItemLabel, DoclingDocument, DocumentOrigin, + GroupItem, + ListItem, PictureItem, ProvenanceItem, SectionHeaderItem, Size, TableCell, + TableData, TableItem, TextItem, ) -from docling_core.types.doc.document import ContentLayer, GroupItem, ListItem, TableData from docling_core.types.legacy_doc.base import ( BaseCell, BaseText, @@ -164,6 +167,7 @@ def docling_document_to_legacy(doc: DoclingDocument, fallback_filaname: str = "f span=[0, len(item.text)], ) for p in item.prov + if isinstance(p, ProvenanceItem) ] main_text.append( BaseText( @@ -287,6 +291,7 @@ def _make_spans(cell: TableCell, table_item: TableItem): span=[0, 0], ) for p in item.prov + if isinstance(p, ProvenanceItem) ], ) ) @@ -314,6 +319,7 @@ def _make_spans(cell: TableCell, table_item: TableItem): span=[0, len(caption)], ) for p in item.prov + if isinstance(p, ProvenanceItem) ], obj_type=doc_item_label_to_legacy_type(item.label), text=caption, diff --git a/docs/DoclingDocument.json b/docs/DoclingDocument.json index 365a62bf..15c9e24e 100644 --- a/docs/DoclingDocument.json +++ b/docs/DoclingDocument.json @@ -233,7 +233,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -606,7 +613,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -740,7 +754,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -812,13 +833,21 @@ "prov": { "anyOf": [ { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, { "type": "null" } ], - "default": null + "default": null, + "title": "Prov" }, "item_ref": { "anyOf": [ @@ -1137,7 +1166,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -1301,7 +1337,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -1669,7 +1712,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2054,16 +2104,19 @@ "type": "object" }, "ProvenanceItem": { - "description": "ProvenanceItem.", + "description": "Provenance information for elements extracted from a textual document.\n\nA `ProvenanceItem` object acts as a lightweight pointer back into the original\ndocument for an extracted element. It applies to documents with an explicity\nor implicit layout, such as PDF, HTML, docx, or pptx.", "properties": { "page_no": { + "description": "Page number", "title": "Page No", "type": "integer" }, "bbox": { - "$ref": "#/$defs/BoundingBox" + "$ref": "#/$defs/BoundingBox", + "description": "Bounding box" }, "charspan": { + "description": "Character span (0-indexed)", "maxItems": 2, "minItems": 2, "prefixItems": [ @@ -2086,6 +2139,120 @@ "title": "ProvenanceItem", "type": "object" }, + "ProvenanceTrack": { + "description": "Provenance information for elements extracted from media assets.\n\nA `ProvenanceTrack` instance describes a cue in a text track associated with a\nmedia element (audio, video, subtitles, screen recordings, ...).", + "properties": { + "start_time": { + "description": "Start time offset of the track cue in seconds", + "examples": [ + 11.0, + 6.5, + 5370.0 + ], + "title": "Start Time", + "type": "number" + }, + "end_time": { + "description": "End time offset of the track cue in seconds", + "examples": [ + 12.0, + 8.2, + 5370.1 + ], + "title": "End Time", + "type": "number" + }, + "identifier": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "An identifier of the cue", + "examples": [ + "test", + "123", + "b72d946" + ], + "title": "Identifier" + }, + "voice": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "The cue voice (speaker)", + "examples": [ + "Mary", + "Fred", + "Name Surname" + ], + "title": "Voice" + }, + "languages": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Languages of the cue in BCP 47 language tag format", + "examples": [ + [ + "en", + "en-GB" + ], + [ + "fr-CA" + ] + ], + "title": "Languages" + }, + "classes": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "minItems": 1, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Classes for describing the cue significance", + "examples": [ + "first", + "loud", + "yellow" + ], + "title": "Classes" + } + }, + "required": [ + "start_time", + "end_time" + ], + "title": "ProvenanceTrack", + "type": "object" + }, "RefItem": { "description": "RefItem.", "properties": { @@ -2242,7 +2409,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2529,7 +2703,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2726,7 +2907,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" @@ -2830,7 +3018,14 @@ "prov": { "default": [], "items": { - "$ref": "#/$defs/ProvenanceItem" + "oneOf": [ + { + "$ref": "#/$defs/ProvenanceItem" + }, + { + "$ref": "#/$defs/ProvenanceTrack" + } + ] }, "title": "Prov", "type": "array" diff --git a/test/data/webvtt/webvtt_example_01.vtt b/test/data/webvtt/webvtt_example_01.vtt new file mode 100644 index 00000000..333ca4a8 --- /dev/null +++ b/test/data/webvtt/webvtt_example_01.vtt @@ -0,0 +1,42 @@ +WEBVTT + +NOTE Copyright © 2019 World Wide Web Consortium. https://www.w3.org/TR/webvtt1/ + +00:11.000 --> 00:13.000 +We are in New York City + +00:13.000 --> 00:16.000 +We’re actually at the Lucern Hotel, just down the street + +00:16.000 --> 00:18.000 +from the American Museum of Natural History + +00:18.000 --> 00:20.000 +And with me is Neil deGrasse Tyson + +00:20.000 --> 00:22.000 +Astrophysicist, Director of the Hayden Planetarium + +00:22.000 --> 00:24.000 +at the AMNH. + +00:24.000 --> 00:26.000 +Thank you for walking down here. + +00:27.000 --> 00:30.000 +And I want to do a follow-up on the last conversation we did. + +00:30.000 --> 00:31.500 align:right size:50% +When we e-mailed— + +00:30.500 --> 00:32.500 align:left size:50% +Didn’t we talk about enough in that conversation? + +00:32.000 --> 00:35.500 align:right size:50% +No! No no no no; 'cos 'cos obviously 'cos + +00:32.500 --> 00:33.500 align:left size:50% +Laughs + +00:35.500 --> 00:38.000 +You know I’m so excited my glasses are falling off here. diff --git a/test/data/webvtt/webvtt_example_02.vtt b/test/data/webvtt/webvtt_example_02.vtt new file mode 100644 index 00000000..1152a1e8 --- /dev/null +++ b/test/data/webvtt/webvtt_example_02.vtt @@ -0,0 +1,15 @@ +WEBVTT + +NOTE Copyright © 2019 World Wide Web Consortium. https://www.w3.org/TR/webvtt1/ + +00:00.000 --> 00:02.000 +It’s a blue apple tree! + +00:02.000 --> 00:04.000 +No way! + +00:04.000 --> 00:06.000 +Hee! laughter + +00:06.000 --> 00:08.000 +That’s awesome! \ No newline at end of file diff --git a/test/data/webvtt/webvtt_example_03.vtt b/test/data/webvtt/webvtt_example_03.vtt new file mode 100644 index 00000000..a4dc1291 --- /dev/null +++ b/test/data/webvtt/webvtt_example_03.vtt @@ -0,0 +1,57 @@ +WEBVTT + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-0 +00:00:04.963 --> 00:00:08.571 +OK, +I think now we should be recording + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-1 +00:00:08.571 --> 00:00:09.403 +properly. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/16-0 +00:00:10.683 --> 00:00:11.563 +Good. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/17-0 +00:00:13.363 --> 00:00:13.803 +Yeah. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/78-0 +00:00:49.603 --> 00:00:53.363 +I was also thinking. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/113-0 +00:00:54.963 --> 00:01:02.072 +Would be maybe good to create items, + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/113-1 +00:01:02.072 --> 00:01:06.811 +some metadata, +some options that can be specific. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/150-0 +00:01:10.243 --> 00:01:13.014 +Yeah, +I mean I think you went even more than + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/119-0 +00:01:10.563 --> 00:01:12.643 +But we preserved the atoms. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/150-1 +00:01:13.014 --> 00:01:15.907 +than me. +I just opened the format. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/197-1 +00:01:50.222 --> 00:01:51.643 +give it a try, yeah. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/200-0 +00:01:52.043 --> 00:01:55.043 +Okay, talk to you later. + +62357a1d-d250-41d5-a1cf-6cc0eeceffcc/202-0 +00:01:54.603 --> 00:01:55.283 +See you. \ No newline at end of file diff --git a/test/data/webvtt/webvtt_example_04.vtt b/test/data/webvtt/webvtt_example_04.vtt new file mode 100644 index 00000000..91be3530 --- /dev/null +++ b/test/data/webvtt/webvtt_example_04.vtt @@ -0,0 +1,13 @@ +WEBVTT + +NOTE Copyright © 2019 World Wide Web Consortium. https://www.w3.org/TR/webvtt1/ + +00:01.000 --> 00:04.000 +Never drink liquid nitrogen. + +NOTE I’m not sure the timing is right on the following cue. + +00:05.000 --> 00:09.000 +— It will perforate your stomach. +— You could die. +This is true. \ No newline at end of file diff --git a/test/test_doc_base.py b/test/test_doc_base.py index 709e2eac..2d1ce498 100644 --- a/test/test_doc_base.py +++ b/test/test_doc_base.py @@ -1,6 +1,7 @@ import pytest from pydantic import ValidationError +from docling_core.types.doc import ProvenanceTrack from docling_core.types.legacy_doc.base import Prov, S3Reference @@ -37,3 +38,40 @@ def test_prov(): with pytest.raises(ValidationError, match="at least 2 items"): prov["span"] = [0] Prov(**prov) + + +def test_prov_track(): + """Test the class ProvenanceTrack.""" + + valid_track = ProvenanceTrack( + start_time=11.0, + end_time=12.0, + identifier="test", + voice="Mary", + languages=["en", "en-GB"], + classes=["v.first.loud", "i.foreignphrase"], + ) + + assert valid_track + assert valid_track.start_time == 11.0 + assert valid_track.end_time == 12.0 + assert valid_track.identifier == "test" + assert valid_track.voice == "Mary" + assert valid_track.languages == ["en", "en-GB"] + assert valid_track.classes == ["v.first.loud", "i.foreignphrase"] + + with pytest.raises(ValidationError, match="end_time"): + ProvenanceTrack(start_time=11.0) + + with pytest.raises(ValidationError, match="should be a valid list"): + ProvenanceTrack( + start_time=11.0, + end_time=12.0, + languages="en", + ) + + with pytest.raises(ValidationError, match="must be greater than start"): + ProvenanceTrack( + start_time=11.0, + end_time=11.0, + ) diff --git a/test/test_webvtt.py b/test/test_webvtt.py new file mode 100644 index 00000000..1bf9edb8 --- /dev/null +++ b/test/test_webvtt.py @@ -0,0 +1,298 @@ +"""Test the data model for WebVTT files. + +Assisted by watsonx Code Assistant. +Examples extracted from https://www.w3.org/TR/webvtt1/ +Copyright © 2019 World Wide Web Consortium. +""" + +import pytest +from pydantic import ValidationError + +from docling_core.types.doc.webvtt import ( + WebVTTCueBlock, + WebVTTCueComponentWithTerminator, + WebVTTCueInternalText, + WebVTTCueItalicSpan, + WebVTTCueLanguageSpan, + WebVTTCueLanguageSpanStartTag, + WebVTTCueSpanStartTagAnnotated, + WebVTTCueTextSpan, + WebVTTCueTimings, + WebVTTCueVoiceSpan, + WebVTTFile, + WebVTTTimestamp, +) + +from .test_data_gen_flag import GEN_TEST_DATA + +GENERATE = GEN_TEST_DATA + + +def test_vtt_cue_commponents() -> None: + """Test WebVTT components.""" + valid_timestamps = [ + "00:01:02.345", + "12:34:56.789", + "02:34.567", + "00:00:00.000", + ] + valid_total_seconds = [ + 1 * 60 + 2.345, + 12 * 3600 + 34 * 60 + 56.789, + 2 * 60 + 34.567, + 0.0, + ] + for idx, ts in enumerate(valid_timestamps): + model = WebVTTTimestamp(raw=ts) + assert model.seconds == valid_total_seconds[idx] + + """Test invalid WebVTT timestamps.""" + invalid_timestamps = [ + "00:60:02.345", # minutes > 59 + "00:01:60.345", # seconds > 59 + "00:01:02.1000", # milliseconds > 999 + "01:02:03", # missing milliseconds + "01:02", # missing milliseconds + ":01:02.345", # extra : for missing hours + "abc:01:02.345", # invalid format + ] + for ts in invalid_timestamps: + with pytest.raises(ValidationError): + WebVTTTimestamp(raw=ts) + + """Test the timestamp __str__ method.""" + model = WebVTTTimestamp(raw="00:01:02.345") + assert str(model) == "00:01:02.345" + + """Test valid cue timings.""" + start = WebVTTTimestamp(raw="00:10.005") + end = WebVTTTimestamp(raw="00:14.007") + cue_timings = WebVTTCueTimings(start=start, end=end) + assert cue_timings.start == start + assert cue_timings.end == end + assert str(cue_timings) == "00:10.005 --> 00:14.007" + + """Test invalid cue timings with end timestamp before start.""" + start = WebVTTTimestamp(raw="00:10.700") + end = WebVTTTimestamp(raw="00:10.500") + with pytest.raises(ValidationError) as excinfo: + WebVTTCueTimings(start=start, end=end) + assert "End timestamp must be greater than start timestamp" in str(excinfo.value) + + """Test invalid cue timings with missing end.""" + start = WebVTTTimestamp(raw="00:10.500") + with pytest.raises(ValidationError) as excinfo: + WebVTTCueTimings(start=start) # type: ignore[call-arg] + assert "Field required" in str(excinfo.value) + + """Test invalid cue timings with missing start.""" + end = WebVTTTimestamp(raw="00:10.500") + with pytest.raises(ValidationError) as excinfo: + WebVTTCueTimings(end=end) # type: ignore[call-arg] + assert "Field required" in str(excinfo.value) + + """Test with valid text.""" + valid_text = "This is a valid cue text span." + span = WebVTTCueTextSpan(text=valid_text) + assert span.text == valid_text + assert str(span) == valid_text + + """Test with text containing newline characters.""" + invalid_text = "This cue text span\ncontains a newline." + with pytest.raises(ValidationError): + WebVTTCueTextSpan(text=invalid_text) + + """Test with text containing ampersand.""" + invalid_text = "This cue text span contains &." + with pytest.raises(ValidationError): + WebVTTCueTextSpan(text=invalid_text) + invalid_text = "An invalid &foo; entity" + with pytest.raises(ValidationError): + WebVTTCueTextSpan(text=invalid_text) + valid_text = "My favorite book is Pride & Prejudice" + span = WebVTTCueTextSpan(text=valid_text) + assert span.text == valid_text + + """Test with text containing less-than sign.""" + invalid_text = "This cue text span contains <." + with pytest.raises(ValidationError): + WebVTTCueTextSpan(text=invalid_text) + + """Test with empty text.""" + with pytest.raises(ValidationError): + WebVTTCueTextSpan(text="") + + """Test that annotation validation works correctly.""" + valid_annotation = "valid-annotation" + invalid_annotation = "invalid\nannotation" + with pytest.raises(ValidationError): + WebVTTCueSpanStartTagAnnotated(name="v", annotation=invalid_annotation) + assert WebVTTCueSpanStartTagAnnotated(name="v", annotation=valid_annotation) + + """Test that classes validation works correctly.""" + annotation = "speaker name" + valid_classes = ["class1", "class2"] + invalid_classes = ["class\nwith\nnewlines", ""] + with pytest.raises(ValidationError): + WebVTTCueSpanStartTagAnnotated( + name="v", annotation=annotation, classes=invalid_classes + ) + assert WebVTTCueSpanStartTagAnnotated( + name="v", annotation=annotation, classes=valid_classes + ) + + """Test that components validation works correctly.""" + annotation = "speaker name" + valid_components = [ + WebVTTCueComponentWithTerminator( + component=WebVTTCueTextSpan(text="random text") + ) + ] + invalid_components = [123, "not a component"] + with pytest.raises(ValidationError): + WebVTTCueInternalText(components=invalid_components) + assert WebVTTCueInternalText(components=valid_components) + + """Test valid cue voice spans.""" + cue_span = WebVTTCueVoiceSpan( + start_tag=WebVTTCueSpanStartTagAnnotated( + name="v", annotation="speaker", classes=["loud", "clear"] + ), + internal_text=WebVTTCueInternalText( + components=[ + WebVTTCueComponentWithTerminator( + component=WebVTTCueTextSpan(text="random text") + ) + ] + ), + ) + expected_str = "random text" + assert str(cue_span) == expected_str + + cue_span = WebVTTCueVoiceSpan( + start_tag=WebVTTCueSpanStartTagAnnotated(name="v", annotation="speaker"), + internal_text=WebVTTCueInternalText( + components=[ + WebVTTCueComponentWithTerminator( + component=WebVTTCueTextSpan(text="random text") + ) + ] + ), + ) + expected_str = "random text" + assert str(cue_span) == expected_str + + +def test_webvttcueblock_parse() -> None: + """Test the method parse of _WebVTTCueBlock class.""" + raw: str = ( + "04:02.500 --> 04:05.000\n" "J’ai commencé le basket à l'âge de 13, 14 ans\n" + ) + block: WebVTTCueBlock = WebVTTCueBlock.parse(raw) + assert str(block.timings) == "04:02.500 --> 04:05.000" + assert len(block.payload) == 1 + assert isinstance(block.payload[0], WebVTTCueComponentWithTerminator) + assert isinstance(block.payload[0].component, WebVTTCueTextSpan) + assert ( + block.payload[0].component.text + == "J’ai commencé le basket à l'âge de 13, 14 ans" + ) + assert raw == str(block) + + raw = ( + "04:05.001 --> 04:07.800\n" + "Sur les playground, ici à Montpellier\n" + ) + block = WebVTTCueBlock.parse(raw) + assert str(block.timings) == "04:05.001 --> 04:07.800" + assert len(block.payload) == 3 + assert isinstance(block.payload[0], WebVTTCueComponentWithTerminator) + assert isinstance(block.payload[0].component, WebVTTCueTextSpan) + assert block.payload[0].component.text == "Sur les " + assert isinstance(block.payload[1], WebVTTCueComponentWithTerminator) + assert isinstance(block.payload[1].component, WebVTTCueItalicSpan) + assert len(block.payload[1].component.internal_text.components) == 1 + lang_span = block.payload[1].component.internal_text.components[0].component + assert isinstance(lang_span, WebVTTCueLanguageSpan) + assert isinstance( + lang_span.internal_text.components[0].component, WebVTTCueTextSpan + ) + assert lang_span.internal_text.components[0].component.text == "playground" + assert isinstance(block.payload[2], WebVTTCueComponentWithTerminator) + assert isinstance(block.payload[2].component, WebVTTCueTextSpan) + assert block.payload[2].component.text == ", ici à Montpellier" + assert raw == str(block) + + +def test_webvtt_file() -> None: + """Test WebVTT files.""" + with open("./test/data/webvtt/webvtt_example_01.vtt", encoding="utf-8") as f: + content = f.read() + vtt = WebVTTFile.parse(content) + assert len(vtt) == 13 + block = vtt.cue_blocks[11] + assert str(block.timings) == "00:32.500 --> 00:33.500" + assert len(block.payload) == 1 + cue_span = block.payload[0] + assert isinstance(cue_span.component, WebVTTCueVoiceSpan) + assert cue_span.component.start_tag.annotation == "Neil deGrasse Tyson" + assert not cue_span.component.start_tag.classes + assert len(cue_span.component.internal_text.components) == 1 + comp = cue_span.component.internal_text.components[0] + assert isinstance(comp.component, WebVTTCueItalicSpan) + assert len(comp.component.internal_text.components) == 1 + comp2 = comp.component.internal_text.components[0] + assert isinstance(comp2.component, WebVTTCueTextSpan) + assert comp2.component.text == "Laughs" + + with open("./test/data/webvtt/webvtt_example_02.vtt", encoding="utf-8") as f: + content = f.read() + vtt = WebVTTFile.parse(content) + assert len(vtt) == 4 + reverse = ( + "WEBVTT\n\nNOTE Copyright © 2019 World Wide Web Consortium. " + "https://www.w3.org/TR/webvtt1/\n\n" + ) + reverse += "\n".join([str(block) for block in vtt.cue_blocks]) + assert content == reverse.rstrip() + + with open("./test/data/webvtt/webvtt_example_03.vtt", encoding="utf-8") as f: + content = f.read() + vtt = WebVTTFile.parse(content) + assert len(vtt) == 13 + for block in vtt: + assert block.identifier + block = vtt.cue_blocks[0] + assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/15-0" + assert str(block.timings) == "00:00:04.963 --> 00:00:08.571" + assert len(block.payload) == 1 + assert isinstance(block.payload[0].component, WebVTTCueVoiceSpan) + block = vtt.cue_blocks[2] + assert block.identifier == "62357a1d-d250-41d5-a1cf-6cc0eeceffcc/16-0" + assert str(block.timings) == "00:00:10.683 --> 00:00:11.563" + assert len(block.payload) == 1 + assert isinstance(block.payload[0].component, WebVTTCueTextSpan) + assert block.payload[0].component.text == "Good." + + with open("./test/data/webvtt/webvtt_example_04.vtt", encoding="utf-8") as f: + content = f.read() + vtt = WebVTTFile.parse(content) + assert len(vtt) == 2 + block = vtt.cue_blocks[1] + assert len(block.payload) == 5 + assert str(block) == ( + "00:05.000 --> 00:09.000\n" + "— It will perforate your stomach.\n" + "— You could die.\n" + "This is true.\n" + ) + + +def test_webvtt_cue_language_span_start_tag(): + WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en"}') + WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en-US"}') + WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "zh-Hant"}') + with pytest.raises(ValidationError, match="BCP 47"): + WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en_US"}') + with pytest.raises(ValidationError, match="BCP 47"): + WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "123-de"}')