Improve LangChain document example
Loader and splitter are now separated into a Source and a Processor
cbornet committed Oct 9, 2023
1 parent a307600 commit 7ec757b
Showing 4 changed files with 85 additions and 25 deletions.
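The reorganized example splits one agent into two. As a quick orientation — a sketch of my own, not code from this commit, using only the classes and record shapes visible in the diffs below (the SimpleRecord wrapping is borrowed from the updated unit test) — the two agents can be chained by hand roughly like this:

from langstream import SimpleRecord

from langchain_agents import (
    LangChainDocumentLoaderSource,
    LangChainTextSplitterProcessor,
)

# Source half: loads whole documents and emits records shaped
# {"value": {"page_content": ..., "metadata": ...}}.
source = LangChainDocumentLoaderSource()
source.init(
    {
        "loader-class": "WebBaseLoader",
        "loader-args": {"web-path": ["https://langstream.ai/"]},
    }
)

# Processor half: splits each document record into chunk records of the same shape.
processor = LangChainTextSplitterProcessor()
processor.init({"splitter-class": "RecursiveCharacterTextSplitter"})

for record in source.read():
    for chunk in processor.process(SimpleRecord(record["value"])):
        print(chunk["value"]["page_content"][:80])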
examples/applications/langchain-document-loader/pipeline.yaml (7 additions, 3 deletions)
@@ -19,13 +19,17 @@ topics:
   - name: "output-topic"
     creation-mode: create-if-not-exists
 pipeline:
-  - name: "Load documents and chunk them with LangChain"
+  - name: "Load documents with LangChain"
     type: "python-source"
     output: "output-topic"
     configuration:
-      className: langchain_document_loader.LangChainDocumentLoaderSource
+      className: langchain_agents.LangChainDocumentLoaderSource
       load-interval-seconds: 3600
       loader-class: WebBaseLoader
       loader-args:
         web-path: ["https://langstream.ai/"]
+  - name: "Chunk documents with LangChain"
+    type: "python-processor"
+    output: "output-topic"
+    configuration:
+      className: langchain_agents.LangChainTextSplitterProcessor
+      splitter-class: RecursiveCharacterTextSplitter
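Worth noting when writing loader-args or splitter-args in this file: both agents rewrite kebab-case keys to snake_case before passing them to the LangChain class constructor, so YAML-friendly names map onto Python keyword arguments. A tiny illustration of the rewrite (the chunk-size values here are made up, not taken from the example):

# The same dict comprehension both agents use in init():
args = {"chunk-size": 200, "chunk-overlap": 20}
kwargs = {k.replace("-", "_"): v for k, v in args.items()}
assert kwargs == {"chunk_size": 200, "chunk_overlap": 20}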
langchain_agents.py (renamed from langchain_document_loader.py)
@@ -15,10 +15,14 @@
 #
 
 import importlib
+import json
 import time
-from typing import List, Dict
+from typing import List, Dict, Optional
 
-from langstream import Source
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+from langchain.text_splitter import TextSplitter
+from langstream import Source, Processor, Record
 
 LOADERS_MODULE = importlib.import_module("langchain.document_loaders")
 SPLITTERS_MODULE = importlib.import_module("langchain.text_splitter")
@@ -27,26 +31,19 @@
 class LangChainDocumentLoaderSource(Source):
     def __init__(self):
         self.load_interval = -1
-        self.loader = None
-        self.splitter = None
+        self.loader: Optional[BaseLoader] = None
         self.first_run = True
 
     def init(self, config):
         self.load_interval = config.get("load-interval-seconds", -1)
         loader_class = config.get("loader-class", "")
         if not hasattr(LOADERS_MODULE, loader_class):
             raise ValueError(f"Unknown loader: {loader_class}")
-        kwargs = {k.replace("-", "_"): v for k, v in
-                  config.get("loader-args", {}).items()}
+        kwargs = {
+            k.replace("-", "_"): v for k, v in config.get("loader-args", {}).items()
+        }
         self.loader = getattr(LOADERS_MODULE, loader_class)(**kwargs)
 
-        splitter_class = config.get("splitter-class", "RecursiveCharacterTextSplitter")
-        if not hasattr(SPLITTERS_MODULE, splitter_class):
-            raise ValueError(f"Unknown loader: {splitter_class}")
-        kwargs = {k.replace("-", "_"): v for k, v in
-                  config.get("splitter-args", {}).items()}
-        self.splitter = getattr(SPLITTERS_MODULE, splitter_class)(**kwargs)
-
     def read(self) -> List[Dict]:
         if not self.first_run:
             if self.load_interval == -1:
@@ -55,5 +52,42 @@ def read(self) -> List[Dict]:
             time.sleep(self.load_interval)
         else:
             self.first_run = False
-        docs = self.loader.load_and_split(text_splitter=self.splitter)
-        return [{"value": doc.page_content} for doc in docs]
+        docs = self.loader.load()
+        return [
+            {"value": {"page_content": doc.page_content, "metadata": doc.metadata}}
+            for doc in docs
+        ]
+
+
+class LangChainTextSplitterProcessor(Processor):
+    def __init__(self):
+        self.splitter: Optional[TextSplitter] = None
+
+    def init(self, config):
+        splitter_class = config.get("splitter-class", "RecursiveCharacterTextSplitter")
+        if not hasattr(SPLITTERS_MODULE, splitter_class):
+            raise ValueError(f"Unknown splitter: {splitter_class}")
+        kwargs = {
+            k.replace("-", "_"): v for k, v in config.get("splitter-args", {}).items()
+        }
+        self.splitter = getattr(SPLITTERS_MODULE, splitter_class)(**kwargs)
+
+    def process(self, record: Record) -> List[Dict]:
+        doc = record.value()
+        if isinstance(doc, str) or isinstance(doc, bytes):
+            doc = json.loads(doc)
+
+        if not isinstance(doc, dict):
+            raise ValueError(f"Invalid record value received {record.value()}")
+
+        chunks = self.splitter.split_documents(
+            [
+                Document(
+                    page_content=doc["page_content"], metadata=doc.get("metadata", {})
+                )
+            ]
+        )
+        return [
+            {"value": {"page_content": chunk.page_content, "metadata": chunk.metadata}}
+            for chunk in chunks
+        ]
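One behavior of the new processor that the updated tests further below don't exercise: if a record's value arrives as a JSON string or bytes (say, from a topic carrying serialized documents), process() decodes it before splitting. A hedged sketch, reusing the CharacterTextSplitter settings from the updated unit test:

import json

from langstream import SimpleRecord

from langchain_agents import LangChainTextSplitterProcessor

processor = LangChainTextSplitterProcessor()
processor.init(
    {
        "splitter-class": "CharacterTextSplitter",
        "splitter-args": {"separator": " ", "chunk-size": 1, "chunk-overlap": 0},
    }
)

# A serialized document, as it might arrive from an upstream topic.
payload = json.dumps({"page_content": "Lorem Ipsum", "metadata": {"source": "foo"}})
records = processor.process(SimpleRecord(payload))
assert [r["value"]["page_content"] for r in records] == ["Lorem", "Ipsum"]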
@@ -14,3 +14,4 @@
 # limitations under the License.
 #
 
+#
@@ -17,28 +17,49 @@
 import os
 
 import yaml
+from langstream import SimpleRecord
 
-from langchain_document_loader import LangChainDocumentLoaderSource
+from langchain_agents import (
+    LangChainDocumentLoaderSource,
+    LangChainTextSplitterProcessor,
+)
 
 dir_path = os.path.dirname(os.path.abspath(__file__))
 
 
-def test():
+def test_document_loader():
     source = LangChainDocumentLoaderSource()
 
     config = f"""
     loader-class: TextLoader
     loader-args:
       file_path: {dir_path}/lorem.txt
+    """
+
+    source.init(yaml.safe_load(config))
+    records = source.read()
+    assert len(records) == 1
+    assert records[0]["value"]["page_content"] == "Lorem Ipsum"
+    assert records[0]["value"]["metadata"]["source"].endswith("lorem.txt")
+    assert source.read() == []
+
+
+def test_text_splitter():
+    processor = LangChainTextSplitterProcessor()
+
+    config = """
     splitter-class: CharacterTextSplitter
     splitter-args:
       separator: " "
       chunk-size: 1
       chunk-overlap: 0
     """
 
-    source.init(yaml.safe_load(config))
-    records = source.read()
-    assert records[0]["value"] == "Lorem"
-    assert records[1]["value"] == "Ipsum"
-    assert source.read() == []
+    processor.init(yaml.safe_load(config))
+    records = processor.process(
+        SimpleRecord({"page_content": "Lorem Ipsum", "metadata": {"source": "foo"}})
+    )
+    assert records[0]["value"]["page_content"] == "Lorem"
+    assert records[0]["value"]["metadata"] == {"source": "foo"}
+    assert records[1]["value"]["page_content"] == "Ipsum"
+    assert records[1]["value"]["metadata"] == {"source": "foo"}