This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Merge pull request #88 from climatepolicyradar/feature/upgrade-pydantic
Update Pydantic Version
THOR300 authored Nov 8, 2023
2 parents 0a2bc68 + 7d7a3c0 commit 207881e
Showing 10 changed files with 317 additions and 240 deletions.
210 changes: 143 additions & 67 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -8,7 +8,7 @@ packages = [{include = "cpr_data_access", from = "src"}]
 
 [tool.poetry.dependencies]
 python = "^3.9"
-pydantic = "^1.10.2"
+pydantic = "^2.4.0"
 boto3 = "^1.26.16"
 tqdm = "^4.64.1"
 aws-error-utils = "^2.7.0"
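Note: "^2.4.0" is a caret constraint, resolving to >=2.4.0,<3.0.0, so the package now tracks the Pydantic v2 API line. A minimal sketch, not part of this commit, for gating on the installed major version at import time:

import pydantic

# pydantic.VERSION is exposed by both v1 and v2
if int(pydantic.VERSION.split(".")[0]) < 2:
    raise ImportError(f"pydantic>=2 required, found {pydantic.VERSION}")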
9 changes: 4 additions & 5 deletions src/cpr_data_access/data_adaptors.py
@@ -4,7 +4,6 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional
 from pathlib import Path
-
 from tqdm.auto import tqdm
 
 from cpr_data_access.parser_models import BaseParserOutput
@@ -64,7 +63,7 @@ def load_dataset(
         for filename in tqdm(s3_objects[:limit]):
             if filename.endswith(".json"):
                 parsed_files.append(
-                    BaseParserOutput.parse_raw(
+                    BaseParserOutput.model_validate_json(
                         _s3_object_read_text(f"{dataset_key}/{filename.split('/')[-1]}")
                     )
                 )
@@ -83,7 +82,7 @@ def get_by_id(
         """
 
         try:
-            return BaseParserOutput.parse_raw(
+            return BaseParserOutput.model_validate_json(
                 _s3_object_read_text(f"s3://{dataset_key}/{document_id}.json")
             )
         except ValueError as e:
@@ -143,7 +142,7 @@ def _load_files(file_paths: list[Path], batch_idx: int, num_batches: int):
         raw_files,
         desc=f"Loading files from directory in batch {batch_idx + 1}/{num_batches}",
     ):
-        parsed_files.append(BaseParserOutput.parse_raw(raw_file_text))
+        parsed_files.append(BaseParserOutput.model_validate_json(raw_file_text))
 
     return parsed_files
 
@@ -170,4 +169,4 @@ def get_by_id(
         if not file_path.exists():
             return None
 
-        return BaseParserOutput.parse_raw(file_path.read_text())
+        return BaseParserOutput.model_validate_json(file_path.read_text())
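Note: the four call-site changes in this file are the standard v1-to-v2 rename: the v1 classmethod parse_raw is deprecated in favour of model_validate_json, which takes a JSON string or bytes and returns a validated model. A minimal sketch on a toy model, not from this repo:

from pydantic import BaseModel

class Output(BaseModel):
    id: str

raw = '{"id": "doc-1"}'
out = Output.model_validate_json(raw)  # Pydantic v2
# out = Output.parse_raw(raw)          # Pydantic v1 spelling, deprecated in v2
assert out.id == "doc-1"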
133 changes: 84 additions & 49 deletions src/cpr_data_access/models/__init__.py
@@ -26,10 +26,10 @@
     AnyHttpUrl,
     BaseModel,
     Field,
-    constr,
+    StringConstraints,
     NonNegativeInt,
     PrivateAttr,
-    root_validator,
+    model_validator,
 )
 from tqdm.auto import tqdm
 import numpy as np
@@ -115,7 +115,7 @@ class KnowledgeBaseIDs(BaseModel):
     """Store for knowledge base IDs."""
 
     wikipedia_title: Optional[str]
-    wikidata_id: Optional[Annotated[str, constr(regex=r"^Q\d+$")]] # type: ignore
+    wikidata_id: Optional[Annotated[str, StringConstraints(pattern=r"^Q\d+$")]] # type: ignore
 
     class Config:
         """
@@ -162,19 +162,19 @@ def __hash__(self):
         """Make hashable."""
         return hash((type(self),) + tuple(self.__dict__.values()))
 
-    @root_validator
-    def _is_valid(cls, values):
+    @model_validator(mode="after")
+    def _is_valid(self):
         """Check that the span is valid, and convert label and id to a consistent format."""
 
-        if values["start_idx"] + len(values["text"]) != values["end_idx"]:
+        if self.start_idx + len(self.text) != self.end_idx:
             raise ValueError(
                 "Values of 'start_idx', 'end_idx' and 'text' are not consistent. 'end_idx' should be 'start_idx' + len('text')."
             )
 
-        values["type"] = values["type"].upper().replace(" ", "_")
-        values["id"] = values["id"].upper().replace(" ", "_")
+        self.type = self.type.upper().replace(" ", "_")
+        self.id = self.id.upper().replace(" ", "_")
 
-        return values
+        return self
 
 
 class TextBlock(BaseModel):
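Note: the @root_validator rewrite above uses v2's @model_validator(mode="after"), which runs on the constructed instance: it receives self rather than a values dict and must return the instance. A minimal sketch, illustrative:

from pydantic import BaseModel, model_validator

class Span(BaseModel):
    text: str
    start_idx: int
    end_idx: int

    @model_validator(mode="after")
    def _check(self):
        # In "after" mode the fields are already validated and set on self
        if self.start_idx + len(self.text) != self.end_idx:
            raise ValueError("'end_idx' should be 'start_idx' + len('text')")
        return self

Span(text="abc", start_idx=0, end_idx=3)  # ok; end_idx=4 would raise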
@@ -185,11 +185,11 @@ class Config: # noqa: D106
 
     text: Sequence[str]
     text_block_id: str
-    language: Optional[str]
+    language: Optional[str] = None
    type: BlockType
     type_confidence: Annotated[float, Field(ge=0, le=1)]
     page_number: Annotated[int, Field(ge=-1)]
-    coords: Optional[List[Tuple[float, float]]]
+    coords: Optional[List[Tuple[float, float]]] = None
     _spans: list[Span] = PrivateAttr(default_factory=list)
 
     def to_string(self) -> str:
@@ -382,7 +382,7 @@ class PageMetadata(BaseModel):
 class BaseMetadata(BaseModel):
     """Metadata that we expect to appear in every document. Should be kept minimal."""
 
-    geography: Optional[str]
+    geography: Optional[str] = None
     publication_ts: Optional[datetime.datetime]


@@ -391,16 +391,18 @@ class BaseDocument(BaseModel):
 
     document_id: str
     document_name: str
-    document_source_url: Optional[AnyHttpUrl]
-    document_content_type: Optional[str]
-    document_md5_sum: Optional[str]
-    languages: Optional[Sequence[str]]
+    document_source_url: Optional[AnyHttpUrl] = None
+    document_content_type: Optional[str] = None
+    document_md5_sum: Optional[str] = None
+    languages: Optional[Sequence[str]] = None
     translated: bool
     has_valid_text: bool
-    text_blocks: Optional[Sequence[TextBlock]] # None if there is no content type
+    text_blocks: Optional[
+        Sequence[TextBlock]
+    ] = None # None if there is no content type
     page_metadata: Optional[
         Sequence[PageMetadata]
-    ] # Properties such as page numbers and dimensions for paged documents
+    ] = None # Properties such as page numbers and dimensions for paged documents
     document_metadata: BaseMetadata
     # The current fields are set in the document parser:
     # https://github.com/climatepolicyradar/navigator-document-parser/blob/5a2872389a85e9f81cdde148b388383d7490807e/cli/parse_pdfs.py#L435
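Note: most field changes in this file follow from a v2 breaking change: Optional[X] no longer implies a default of None, so fields that should remain omittable need an explicit = None. A sketch of the difference, illustrative:

from typing import Optional
from pydantic import BaseModel

class M(BaseModel):
    a: Optional[str]         # v2: required, though None is an accepted value
    b: Optional[str] = None  # v2: optional, defaults to None

M(a=None)  # ok
# M()      # ValidationError: field 'a' is required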
@@ -437,11 +439,11 @@ def from_parser_output(
         elif parser_document.document_content_type == CONTENT_TYPE_PDF:
             has_valid_text = True
             text_blocks = [
-                TextBlock.parse_obj(block)
+                TextBlock.model_validate(block.model_dump())
                 for block in (parser_document.pdf_data.text_blocks) # type: ignore
             ]
             page_metadata = [
-                PageMetadata.parse_obj(meta)
+                PageMetadata.model_validate(meta.model_dump())
                 for meta in parser_document.pdf_data.page_metadata # type: ignore
             ]
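Note: parse_obj becomes model_validate in v2, and the added block.model_dump() / meta.model_dump() calls first convert the parser-model instances to plain dicts so they can be re-validated into this package's own TextBlock and PageMetadata classes. A sketch of the re-parse pattern with toy models:

from pydantic import BaseModel

class ParserBlock(BaseModel):
    text: str

class TextBlock(BaseModel):
    text: str

src = ParserBlock(text="hello")
dst = TextBlock.model_validate(src.model_dump())  # dump to dict, re-validate
assert dst.text == "hello"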

@@ -450,7 +452,9 @@
                 f"Unsupported content type: {parser_document.document_content_type}"
             )
 
-        parser_document_data = parser_document.dict(exclude={"html_data", "pdf_data"})
+        parser_document_data = parser_document.model_dump(
+            exclude={"html_data", "pdf_data"}
+        )
         metadata = {
             "document_metadata": parser_document.document_metadata,
             "pipeline_metadata": parser_document.pipeline_metadata,
@@ -461,7 +465,7 @@
             "has_valid_text": has_valid_text,
         }
 
-        return cls.parse_obj(parser_document_data | metadata | text_and_page_data)
+        return cls.model_validate(parser_document_data | metadata | text_and_page_data)
 
     @classmethod
     def load_from_remote(
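Note: .dict() likewise becomes model_dump(), with the same include/exclude semantics. A sketch, illustrative:

from pydantic import BaseModel

class Doc(BaseModel):
    id: str
    html_data: dict = {}
    pdf_data: dict = {}

doc = Doc(id="1")
assert doc.model_dump(exclude={"html_data", "pdf_data"}) == {"id": "1"}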
@@ -669,15 +673,15 @@ class CPRDocumentMetadata(BaseModel):
     source: str
     type: str
     sectors: Sequence[str]
-    collection_id: Optional[str]
-    collection_name: Optional[str]
+    collection_id: Optional[str] = None
+    collection_name: Optional[str] = None
     family_id: str
     family_name: str
     family_slug: str
-    role: Optional[str]
-    variant: Optional[str]
+    role: Optional[str] = None
+    variant: Optional[str] = None
     status: str
-    publication_ts: Optional[datetime.datetime]
+    publication_ts: Optional[datetime.datetime] = None
 
 
 class CPRDocument(BaseDocument):
@@ -692,7 +696,7 @@
 
     document_description: str
     document_slug: str
-    document_cdn_object: Optional[str]
+    document_cdn_object: Optional[str] = None
     document_metadata: CPRDocumentMetadata


@@ -702,16 +706,16 @@ class GSTDocumentMetadata(BaseModel):
     source: str
     author: Sequence[str]
     geography_iso: str
-    types: Optional[Sequence[str]]
+    types: Optional[Sequence[str]] = None
     date: datetime.date
-    link: Optional[str]
+    link: Optional[str] = None
     author_is_party: bool
-    collection_id: Optional[str]
+    collection_id: Optional[str] = None
     family_id: str
     family_name: str
     family_slug: str
-    role: Optional[str]
-    variant: Optional[str]
+    role: Optional[str] = None
+    variant: Optional[str] = None
     status: str


@@ -801,8 +805,8 @@ def _document_id_idx_hash_map(self) -> dict[str, set[int]]:
     def metadata_df(self) -> pd.DataFrame:
         """Return a dataframe of document metadata"""
         metadata = [
-            doc.dict(exclude={"text_blocks", "document_metadata"})
-            | doc.document_metadata.dict()
+            doc.model_dump(exclude={"text_blocks", "document_metadata"})
+            | doc.document_metadata.model_dump()
             | {"num_text_blocks": len(doc.text_blocks) if doc.text_blocks else 0}
             | {"num_pages": len(doc.page_metadata) if doc.page_metadata else 0}
             for doc in self.documents
@@ -900,7 +904,7 @@ def add_metadata(
             else:
                 continue
 
-            doc_dict = document.dict(
+            doc_dict = document.model_dump(
                 exclude={"document_metadata", "_text_block_idx_hash_map"}
             )
             new_metadata_dict = metadata_df.loc[
@@ -920,13 +924,25 @@ def add_metadata(
                     for s in new_metadata_dict.get("Sectors", "").split(";")
                 ],
                 status=new_metadata_dict.pop("CPR Document Status"),
-                collection_id=new_metadata_dict.pop("CPR Collection ID"),
-                collection_name=new_metadata_dict.pop("Collection name"),
+                collection_id=(
+                    new_metadata_dict.pop("CPR Collection ID")
+                    if isinstance(new_metadata_dict.get("CPR Collection ID"), str)
+                    else None
+                ),
+                collection_name=(
+                    new_metadata_dict.pop("Collection name")
+                    if isinstance(new_metadata_dict.get("Collection name"), str)
+                    else None
+                ),
                 family_id=new_metadata_dict.pop("CPR Family ID"),
                 family_name=new_metadata_dict.pop("Family name"),
                 family_slug=new_metadata_dict.pop("CPR Family Slug"),
                 role=new_metadata_dict.pop("Document role"),
-                variant=new_metadata_dict.pop("Document variant"),
+                variant=(
+                    new_metadata_dict.pop("Document variant")
+                    if isinstance(new_metadata_dict.get("Document variant"), str)
+                    else None
+                ),
                 # NOTE: we incorrectly use the "publication_ts" value from the parser output rather than the correct
                 # document date (calculated from events in product). When we upgrade to Vespa we should use the correct
                 # date.
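Note: the new isinstance guards handle missing spreadsheet cells, which pandas yields as float NaN; v1 silently coerced floats to strings for str-typed fields, while v2 rejects them, so non-string values are mapped to None before the metadata model is built. A sketch of the guard, illustrative:

row = {"CPR Collection ID": float("nan")}  # a missing cell as read via pandas

value = row.get("CPR Collection ID")
collection_id = value if isinstance(value, str) else None  # NaN -> None
assert collection_id is None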
@@ -956,12 +972,20 @@ def add_metadata(
                 date=new_metadata_dict.pop("Date"),
                 link=new_metadata_dict.pop("Documents"),
                 author_is_party=new_metadata_dict.pop("Author Type") == "Party",
-                collection_id=new_metadata_dict.pop("CPR Collection ID"),
+                collection_id=(
+                    new_metadata_dict.pop("CPR Collection ID")
+                    if isinstance(new_metadata_dict.get("CPR Collection ID"), str)
+                    else None
+                ),
                 family_id=new_metadata_dict.pop("CPR Family ID"),
                 family_name=new_metadata_dict.pop("Family Name"),
                 family_slug=new_metadata_dict.pop("CPR Family Slug"),
                 role=new_metadata_dict.pop("Document Role"),
-                variant=new_metadata_dict.pop("Document Variant"),
+                variant=(
+                    new_metadata_dict.pop("Document Variant")
+                    if isinstance(new_metadata_dict.get("Document Variant"), str)
+                    else None
+                ),
                 status=new_metadata_dict.pop("CPR Document Status"),
                 author=[
                     s.strip() for s in new_metadata_dict.pop("Author").split(",")
@@ -1094,7 +1118,7 @@ def get_all_text_blocks(
         for doc in self.documents:
             if doc.text_blocks is not None:
                 if with_document_context:
-                    doc_dict = doc.dict(exclude={"text_blocks"})
+                    doc_dict = doc.model_dump(exclude={"text_blocks"})
                     for block in doc.text_blocks:
                         output_values.append((block, doc_dict))
                 else:
@@ -1114,13 +1138,15 @@ def _doc_to_text_block_dicts(self, document: AnyDocument) -> List[Dict[str, Any]]:
             return []
 
         doc_metadata_dict = (
-            document.dict(exclude={"text_blocks", "page_metadata", "document_metadata"})
-            | document.document_metadata.dict()
+            document.model_dump(
+                exclude={"text_blocks", "page_metadata", "document_metadata"}
+            )
+            | document.document_metadata.model_dump()
         )
 
         return [
             doc_metadata_dict
-            | block.dict(exclude={"text"})
+            | block.model_dump(exclude={"text"})
             | {"text": block.to_string(), "block_index": idx}
             for idx, block in enumerate(document.text_blocks)
         ]
@@ -1157,10 +1183,19 @@ def to_huggingface(
             citation=citation or "",
         )
 
+        mapping = {
+            key: [d.get(key, None) for d in text_block_dicts] for key in dict_keys
+        }
+
+        plain_urls = [
+            source_url.path if source_url else None
+            for source_url in mapping["document_source_url"]
+        ]
+
+        mapping["document_source_url"] = plain_urls
+
         huggingface_dataset = HFDataset.from_dict(
-            mapping={
-                key: [d.get(key, None) for d in text_block_dicts] for key in dict_keys
-            },
+            mapping=mapping,
             info=dataset_info,
         )
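Note: this new block exists because AnyHttpUrl is no longer a str subclass in v2; it validates to a Url object, which the Hugging Face Arrow writer cannot serialise directly. Note that .path keeps only the path component, while str(url) would preserve the full URL. A sketch of the v2 behaviour, illustrative:

from pydantic import AnyHttpUrl, TypeAdapter

url = TypeAdapter(AnyHttpUrl).validate_python("https://example.com/doc.pdf")
print(isinstance(url, str))  # False in v2 (True in v1)
print(url.path)              # "/doc.pdf"
print(str(url))              # "https://example.com/doc.pdf"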

@@ -1235,7 +1270,7 @@ def _from_huggingface_parquet(
             "source": "GST" if self.document_model == GSTDocument else "CPR"
         }
 
-        doc = self.document_model.parse_obj(
+        doc = self.document_model.model_validate(
             doc_fields
             | {
                 "document_metadata": doc_metadata_dict,