mirror of
https://github.com/kennethreitz/langchain.git
synced 2026-06-05 23:00:18 +00:00
e494b0a09f
#### Summary
A new approach to loading source code is implemented:
Each top-level function and class in the code is loaded into separate
documents. Then, an additional document is created with the top-level
code, but without the already loaded functions and classes.
This could improve the accuracy of QA chains over source code.
For instance, having this script:
```
class MyClass:
    def __init__(self, name):
        self.name = name
    def greet(self):
        print(f"Hello, {self.name}!")
def main():
    name = input("Enter your name: ")
    obj = MyClass(name)
    obj.greet()
if __name__ == '__main__':
    main()
```
The loader will create three documents with this content:
First document:
```
class MyClass:
    def __init__(self, name):
        self.name = name
    def greet(self):
        print(f"Hello, {self.name}!")
```
Second document:
```
def main():
    name = input("Enter your name: ")
    obj = MyClass(name)
    obj.greet()
```
Third document:
```
# Code for: class MyClass:
# Code for: def main():
if __name__ == '__main__':
    main()
```
A threshold parameter is added to control whether small scripts are
split in this way or not.
At this moment, only Python and JavaScript are supported. The
appropriate parser is determined by examining the file extension.
#### Tests
This PR adds:
- Unit tests
- Integration tests
#### Dependencies
Only one dependency was added as optional (needed for the JavaScript
parser).
#### Documentation
A notebook is added showing how the loader can be used.
#### Who can review?
@eyurtsev @hwchase17
---------
Co-authored-by: rlm <pexpresss31@gmail.com>
48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
import ast
|
|
from typing import Any, List
|
|
|
|
from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter
|
|
|
|
|
|
class PythonSegmenter(CodeSegmenter):
    """Segment Python source into top-level definitions and a simplified rest.

    Every top-level function, async function and class can be extracted as its
    own snippet, and a simplified version of the code can be produced in which
    each of those definitions is collapsed to a one-line ``# Code for: ...``
    marker.
    """

    # AST node types that are treated as separately extractable definitions.
    _DEFINITION_NODES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

    def __init__(self, code: str):
        """Store the code and pre-split it into parser-aligned lines.

        Args:
            code: The Python source text to segment.
        """
        super().__init__(code)
        self.source_lines = self._split_lines(self.code)

    @staticmethod
    def _split_lines(code: str) -> List[str]:
        """Split *code* into lines the same way the Python parser counts them.

        ``str.splitlines`` also breaks on form feed, vertical tab, U+2028 and
        similar characters, which the parser does not treat as line ends. Using
        it would shift every following index relative to ``lineno`` and
        ``end_lineno``. Only ``\\r\\n``, ``\\r`` and ``\\n`` are used here.

        Args:
            code: The source text to split.

        Returns:
            The lines without their terminators. As with ``str.splitlines``, a
            single trailing line terminator does not produce an extra empty
            line.
        """
        lines = code.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        if not lines[-1]:
            lines.pop()
        return lines

    def is_valid(self) -> bool:
        """Return whether the code can be parsed as Python.

        Returns:
            True if ``ast.parse`` accepts the code, False otherwise.
        """
        try:
            ast.parse(self.code)
        except (SyntaxError, ValueError):
            # ValueError is raised for source containing null bytes on some
            # Python versions; such code is just as unparseable.
            return False
        return True

    def _top_level_definitions(self) -> List[Any]:
        """Parse the code and return its top-level function and class nodes.

        Raises:
            SyntaxError: If the code is not valid Python.
        """
        tree = ast.parse(self.code)
        return [
            node
            for node in ast.iter_child_nodes(tree)
            if isinstance(node, self._DEFINITION_NODES)
        ]

    @staticmethod
    def _first_line_index(node: Any) -> int:
        """Return the 0-based index of the first source line of *node*.

        Since Python 3.8 the ``lineno`` of a decorated definition is the line
        of the ``def``/``class`` keyword, so decorator lines are looked up
        explicitly to keep them with the definition they belong to.
        """
        decorators = getattr(node, "decorator_list", None) or []
        first_lineno = min([node.lineno, *(dec.lineno for dec in decorators)])
        return first_lineno - 1

    def _extract_code(self, node: Any) -> str:
        """Return the source text of *node*, including any decorators.

        Args:
            node: An AST node carrying ``lineno`` and ``end_lineno``.

        Returns:
            The node's source lines joined with ``"\\n"``.
        """
        start = self._first_line_index(node)
        end = node.end_lineno
        return "\n".join(self.source_lines[start:end])

    def extract_functions_classes(self) -> List[str]:
        """Return the source of every top-level function and class.

        Returns:
            One string per top-level definition, in source order.

        Raises:
            SyntaxError: If the code is not valid Python.
        """
        return [self._extract_code(node) for node in self._top_level_definitions()]

    def simplify_code(self) -> str:
        """Return the code with top-level definitions collapsed to markers.

        The ``def``/``class`` line of each top-level definition is prefixed
        with ``# Code for: `` and all other lines of that definition (its body
        and its decorators) are removed. Everything else is kept unchanged.

        Returns:
            The simplified code, lines joined with ``"\\n"``.

        Raises:
            SyntaxError: If the code is not valid Python.
        """
        simplified_lines = list(self.source_lines)
        dropped: List[int] = []

        for node in self._top_level_definitions():
            header = node.lineno - 1
            # Narrow the Optional[int] type; ast.parse always sets end_lineno.
            assert isinstance(node.end_lineno, int)

            simplified_lines[header] = f"# Code for: {simplified_lines[header]}"
            # Decorator lines above the header and body lines below it.
            dropped.extend(range(self._first_line_index(node), header))
            dropped.extend(range(header + 1, node.end_lineno))

        skip = set(dropped)
        return "\n".join(
            line for index, line in enumerate(simplified_lines) if index not in skip
        )
|