From 7d84262169186c35e13a5595fe25cca910bea793 Mon Sep 17 00:00:00 2001
From: Shreya Shankar
Date: Sat, 9 Nov 2024 18:24:32 -0800
Subject: [PATCH 1/3] feat: performant resolve

---
 .gitignore                        |   5 +-
 Makefile                          |  31 ++-
 docetl/operations/fast_resolve.py | 345 ++++++++++++++++++++++++++
 docetl/operations/resolve.py      |   1 +
 docetl/rust/Cargo.lock            | 390 ++++++++++++++++++++++++++++++
 docetl/rust/Cargo.toml            |  13 +
 docetl/rust/__init__.py           |   1 +
 docetl/rust/src/lib.rs            | 310 ++++++++++++++++++++++++
 pyproject.toml                    |  12 +-
 tests/test_fast_resolve.py        | 228 +++++++++++++++++
 10 files changed, 1325 insertions(+), 11 deletions(-)
 create mode 100644 docetl/operations/fast_resolve.py
 create mode 100644 docetl/rust/Cargo.lock
 create mode 100644 docetl/rust/Cargo.toml
 create mode 100644 docetl/rust/__init__.py
 create mode 100644 docetl/rust/src/lib.rs
 create mode 100644 tests/test_fast_resolve.py

diff --git a/.gitignore b/.gitignore
index b1d0b6b1..86fd079f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -49,4 +49,7 @@ website/.vercel
 
 # typescript
 website/*.tsbuildinfo
-website/next-env.d.ts
\ No newline at end of file
+website/next-env.d.ts
+
+# Rust
+*target/
\ No newline at end of file
diff --git a/Makefile b/Makefile
index bc147709..5c4213b3 100644
--- a/Makefile
+++ b/Makefile
@@ -1,13 +1,34 @@
 # Load environment variables from .env file
 include .env
 
-.PHONY: tests tests-basic lint install mypy update ui-install ui-run
+.PHONY: tests tests-basic lint install mypy update ui-install ui-run build-rust develop clean
+
+# Build commands
+build-rust:
+	maturin develop --release --manifest-path docetl/rust/Cargo.toml
+
+develop: clean build-rust
+	poetry install --all-extras
+
+clean:
+	rm -rf target/
+	rm -rf docetl/rust/target/
+	rm -f docetl/resolver/resolver*.so
+	find . -type d -name "__pycache__" -exec rm -rf {} +
+	find . -type f -name "*.pyc" -delete
+	find . -type f -name "*.pyo" -delete
+	find . -type f -name "*.so" -delete
+
+# Install command now includes Rust build
+install: clean
+	pip install poetry maturin
+	$(MAKE) develop
 
 # Existing commands
-tests:
+tests: clean build-rust
 	poetry run pytest
 
-tests-basic:
+tests-basic: clean build-rust
 	poetry run pytest tests/basic
 	poetry run pytest tests/test_api.py
 	poetry run pytest tests/test_runner_caching.py
@@ -15,10 +36,6 @@ tests-basic:
 
 lint:
 	poetry run ruff check docetl/* --fix
 
-install:
-	pip install poetry
-	poetry install --all-extras
-
 mypy:
 	poetry run mypy
diff --git a/docetl/operations/fast_resolve.py b/docetl/operations/fast_resolve.py
new file mode 100644
index 00000000..ce478012
--- /dev/null
+++ b/docetl/operations/fast_resolve.py
@@ -0,0 +1,345 @@
+from typing import List, Dict, Tuple, Any, Optional
+from concurrent.futures import ThreadPoolExecutor
+from rich.progress import Progress
+from .base import BaseOperation
+from docetl_resolver import FastResolver
+from rich.console import Console
+from rich.status import Status
+from jinja2 import Template
+import jinja2
+from docetl.operations.utils import RichLoopBar, rich_as_completed
+
+class FastResolveOperation(BaseOperation):
+    class schema(BaseOperation.schema):
+        type: str = "fast_resolve"
+        comparison_prompt: str
+        resolution_prompt: str
+        output: Optional[Dict[str, Any]] = None
+        embedding_model: Optional[str] = None
+        resolution_model: Optional[str] = None
+        comparison_model: Optional[str] = None
+        blocking_threshold: Optional[float] = None
+        blocking_keys: Optional[List[str]] = None
+        embedding_batch_size: Optional[int] = None
+        compare_batch_size: Optional[int] = None
+
+    def syntax_check(self):
+        """Check if the config is valid."""
+        required_keys = ["comparison_prompt", "output"]
+        for key in required_keys:
+            if key not in self.config:
+                raise ValueError(f"Missing required key '{key}' in FastResolveOperation configuration")
+
+        if "schema" not in self.config["output"]:
+            raise ValueError("Missing 'schema' in 'output' configuration")
+
+        if not isinstance(self.config["output"]["schema"], dict):
+            raise TypeError("'schema' in 'output' configuration must be a dictionary")
+
+        if not self.config["output"]["schema"]:
+            raise ValueError("'schema' in 'output' configuration cannot be empty")
+
+        # Check if the comparison_prompt is a valid Jinja2 template
+        try:
+            comparison_template = Template(self.config["comparison_prompt"])
+            comparison_vars = comparison_template.environment.parse(
+                self.config["comparison_prompt"]
+            ).find_all(jinja2.nodes.Name)
+            comparison_var_names = {var.name for var in comparison_vars}
+            if "input1" not in comparison_var_names or "input2" not in comparison_var_names:
+                raise ValueError(
+                    "'comparison_prompt' must contain both 'input1' and 'input2' variables"
+                )
+
+            if "resolution_prompt" in self.config:
+                reduction_template = Template(self.config["resolution_prompt"])
+                reduction_vars = reduction_template.environment.parse(
+                    self.config["resolution_prompt"]
+                ).find_all(jinja2.nodes.Name)
+                reduction_var_names = {var.name for var in reduction_vars}
+                if "inputs" not in reduction_var_names:
+                    raise ValueError("'resolution_prompt' must contain 'inputs' variable")
+        except Exception as e:
+            raise ValueError(f"Invalid Jinja2 template: {str(e)}")
+
+    def __init__(
+        self,
+        runner: "ConfigWrapper",
+        config: Dict,
+        default_model: str,
+        max_threads: int,
+        console: Optional[Console] = None,
+        status: Optional[Status] = None,
+        is_build: bool = False,
+        **kwargs,
+    ):
+        super().__init__(runner, config, default_model, max_threads, console, status, is_build, **kwargs)
+        self.resolver = FastResolver(
+            blocking_threshold=config.get("blocking_threshold", 0.8)
+        )
+
+    def batch_embeddings(self, items: List[Dict], batch_size: int = 1000) -> Tuple[List[List[float]], float]:
+        """Get embeddings for all items in parallel batches."""
+        all_embeddings = []
+        total_cost = 0
+        blocking_keys = self.config.get("blocking_keys", list(items[0].keys()))
+
+        def process_batch(batch):
+            texts = [
+                " ".join(str(item[key]) for key in blocking_keys if key in item)
+                for item in batch
+            ]
+            response = self.runner.api.gen_embedding(
+                model=self.config.get("embedding_model", "text-embedding-3-small"),
+                input=texts
+            )
+            return [data["embedding"] for data in response["data"]], response.get("usage", {}).get("total_tokens", 0) * 0.0001
+
+        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
+            futures = []
+            for i in range(0, len(items), batch_size):
+                batch = items[i:i + batch_size]
+                futures.append(executor.submit(process_batch, batch))
+
+            for future in rich_as_completed(
+                futures,
+                total=len(futures),
+                desc="Generating embeddings",
+                console=self.console
+            ):
+                embeddings, cost = future.result()
+                all_embeddings.extend(embeddings)
+                total_cost += cost
+
+        return all_embeddings, total_cost
+
+    def compare_pair(self, item1: Dict, item2: Dict) -> Tuple[bool, float]:
+        """Compare two items using the LLM."""
+        prompt_template = Template(self.config["comparison_prompt"])
+        prompt = prompt_template.render(input1=item1, input2=item2)
+
+        response = self.runner.api.call_llm(
+            self.config.get("comparison_model", self.default_model),
+            "compare",
+            [{"role": "user", "content": prompt}],
+            {"is_match": "bool"},
+            timeout_seconds=self.config.get("timeout", 120),
+            max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
+            bypass_cache=self.config.get("bypass_cache", False),
+        )
+        output = self.runner.api.parse_llm_response(
+            response.response,
+            {"is_match": "bool"},
+        )[0]
+        return output["is_match"], response.total_cost
+
+    def process_cluster(self, cluster: List[int], items: List[Dict]) -> Tuple[List[Dict], float]:
+        """Process a cluster of items to generate a resolved output."""
+        if len(cluster) == 1:
+            return [items[cluster[0]]], 0
+
+        cluster_items = [items[i] for i in cluster]
+        reduction_template = Template(self.config["resolution_prompt"])
+        resolution_prompt = reduction_template.render(inputs=cluster_items)
+
+        response = self.runner.api.call_llm(
+            self.config.get("resolution_model", self.default_model),
+            "resolve",
+            [{"role": "user", "content": resolution_prompt}],
+            self.config["output"]["schema"],
+            timeout_seconds=self.config.get("timeout", 120),
+            max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
+            bypass_cache=self.config.get("bypass_cache", False),
+            validation_config=(
+                {
+                    "val_rule": self.config.get("validate", []),
+                    "validation_fn": self.validation_fn,
+                }
+                if self.config.get("validate", None)
+                else None
+            ),
+        )
+
+        if response.validated:
+            resolved = self.runner.api.parse_llm_response(
+                response.response,
+                self.config["output"]["schema"],
+                manually_fix_errors=self.manually_fix_errors,
+            )[0]
+
+            results = []
+            for idx in cluster:
+                item = items[idx].copy()
+                # Save original values before overwriting
+                keys_in_output = [k for k in resolved.keys() if k in item.keys()]
+                item[f"_kv_pairs_preresolve_{self.config['name']}"] = {
+                    k: item[k] for k in keys_in_output
+                }
+                item.update(resolved)
+                results.append(item)
+
+            return results, response.total_cost
+
+        return [], response.total_cost
+
+    def validation_fn(self, response: Dict[str, Any]):
+        output = self.runner.api.parse_llm_response(
+            response,
+            schema=self.config["output"]["schema"],
+        )[0]
+        if self.runner.api.validate_output(self.config, output, self.console):
+            return output, True
+        return output, False
+
+    def auto_batch(self, num_pairs: int) -> int:
+        """Calculate optimal batch size based on number of comparisons."""
+        # Maximum batch size limit for 4o-mini model
+        M = 500
+
+        n = len(self.input_data)
+        m = num_pairs
+
+        # https://www.wolframalpha.com/input?i=k%28k-1%29%2F2+%2B+%28n-k%29%28k-1%29+%3D+m%2C+solve+for+k
+        # Two possible solutions for k:
+        #   k = -1/2 sqrt((1 - 2n)^2 - 8m) + n + 1/2
+        #   k = 1/2 (sqrt((1 - 2n)^2 - 8m) + 2n + 1)
+
+        discriminant = (1 - 2*n)**2 - 8*m
+        if discriminant < 0:
+            # No real solution: m exceeds the pairs n items can produce, so fall back to the cap.
+            return M
+        sqrt_discriminant = discriminant ** 0.5
+
+        k1 = -0.5 * sqrt_discriminant + n + 0.5
+        k2 = 0.5 * (sqrt_discriminant + 2*n + 1)
+
+        # Take the maximum viable solution
+        k = max(k1, k2)
+        return M if k < 0 else min(int(k), M)
+
+    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
+        """Execute the fast resolve operation."""
+        if not input_data:
+            return [], 0
+
+        self.input_data = input_data
+        total_cost = 0
+
+        # Set up blocking rules
+        blocking_conditions = self.config.get("blocking_conditions", [])
+        for condition in blocking_conditions:
+            # Parse the condition string to extract keys and operation
+            if "in" in condition:
+                parts = condition.split("in")
+                if parts[0].strip().endswith(".lower()") and parts[1].strip().endswith(".lower()"):
+                    key1 = parts[0].split("[")[1].split("]")[0].strip('"\'')
+                    key2 = parts[1].split("[")[1].split("]")[0].strip('"\'')
+
+                    if parts[0].strip().startswith("input1"):
+                        self.resolver.add_contains_rule(key1, key2)
+                    else:
+                        self.resolver.add_contained_in_rule(key1, key2)
+            elif "==" in condition:
+                parts = condition.split("==")
+                if parts[0].strip().endswith(".lower()") and parts[1].strip().endswith(".lower()"):
+                    key1 = parts[0].split("[")[1].split("]")[0].strip('"\'')
+                    key2 = parts[1].split("[")[1].split("]")[0].strip('"\'')
+                    self.resolver.add_equals_rule(key1, key2)
+
+        # Get embeddings with configurable batch size
+        embedding_batch_size = self.config.get("embedding_batch_size", 1000)
+        embeddings, embedding_cost = self.batch_embeddings(input_data, batch_size=embedding_batch_size)
+        total_cost += embedding_cost
+
+        # Get comparison pairs from Rust, including blocking rules
+        comparison_pairs = self.resolver.process_embeddings(embeddings, input_data)
+
+        # Calculate and log statistics
+        total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
+        comparisons_made = len(comparison_pairs)
+        comparisons_saved = total_possible_comparisons - comparisons_made
+
+        self.console.log(
+            f"[green]Comparisons saved by blocking: {comparisons_saved} "
+            f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
+        )
+        self.console.log(
+            f"[blue]Number of pairs to compare: {comparisons_made}[/blue]"
+        )
+
+        # Calculate batch size for comparisons
+        batch_size = self.config.get("compare_batch_size", self.auto_batch(len(comparison_pairs)))
+        self.console.log(f"Using compare batch size: {batch_size}")
+
+        # Process comparisons in batches with progress bar
+        pbar = RichLoopBar(
+            range(0, len(comparison_pairs), batch_size),
+            desc=f"Processing batches of {batch_size} LLM comparisons",
+            console=self.console,
+        )
+
+        for i in pbar:
+            batch = comparison_pairs[i:i + batch_size]
+
+            with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
+                futures = []
+                valid_pairs = []
+
+                # Pre-filter pairs that might already be in same cluster or processed.
+                # Distinct names (idx1, idx2) are used so the outer loop variable `i`
+                # is not shadowed before the pbar.update(i//batch_size) call below.
+                for idx1, idx2 in batch:
+                    if (self.resolver.find_cluster(idx1) != self.resolver.find_cluster(idx2) and
+                        not self.resolver.is_processed(idx1, idx2)):
+                        futures.append(
+                            executor.submit(
+                                self.compare_pair,
+                                input_data[idx1],
+                                input_data[idx2]
+                            )
+                        )
+                        valid_pairs.append((idx1, idx2))
+
+                # Process results and merge clusters
+                for future, (idx1, idx2) in zip(futures, valid_pairs):
+                    is_match, cost = future.result()
+                    total_cost += cost
+                    # Mark pair as processed regardless of match result
+                    self.resolver.mark_processed(idx1, idx2)
+                    if is_match:
+                        self.resolver.merge_clusters(idx1, idx2)
+
+            pbar.update(i//batch_size)
+
+        # Get final clusters
+        clusters = self.resolver.get_clusters()
+
+        # Calculate and log cluster statistics
+        num_records_before = len(input_data)
+        num_clusters_after = len(clusters)
+        self.console.log(f"Number of records before resolution: {num_records_before}")
+        self.console.log(f"Number of distinct records after resolution: {num_clusters_after}")
+
+        # Calculate and log self-join selectivity
+        true_match_count = sum(
+            len(cluster) * (len(cluster) - 1) // 2
+            for cluster in clusters
+            if len(cluster) > 1
+        )
+        true_match_selectivity = true_match_count / total_possible_comparisons if total_possible_comparisons > 0 else 0
+        self.console.log(f"Self-join selectivity: {true_match_selectivity:.4f}")
+
+        # Process each cluster in parallel with progress
+        results = []
+        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
+            futures = [
+                executor.submit(self.process_cluster, cluster, input_data)
+                for cluster in clusters
+            ]
+
+            for future in rich_as_completed(
+                futures,
+                total=len(futures),
+                desc="Resolving clusters",
+                console=self.console
+            ):
+                cluster_results, cost = future.result()
+                results.extend(cluster_results)
+                total_cost += cost
+
+        return results, total_cost
\ No newline at end of file
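Note on auto_batch above: it solves k(k-1)/2 + (n-k)(k-1) = m for k, i.e. it picks the largest batch size k whose first batch (k(k-1)/2 intra-batch pairs plus (n-k)(k-1) pairs against the remaining items) still fits within the m candidate comparisons. A standalone sketch of the same arithmetic for readers who want to check it (the function name is hypothetical, not part of the patch), with the negative-discriminant guard made explicit:

    import math

    def auto_batch_size(n: int, m: int, cap: int = 500) -> int:
        # Solve k*(k-1)/2 + (n-k)*(k-1) = m for k; cap mirrors the 4o-mini batch limit.
        discriminant = (1 - 2 * n) ** 2 - 8 * m
        if discriminant < 0:
            # m exceeds the pairs n items can produce; fall back to the cap.
            return cap
        sqrt_d = math.sqrt(discriminant)
        k1 = -0.5 * sqrt_d + n + 0.5
        k2 = 0.5 * (sqrt_d + 2 * n + 1)
        k = max(k1, k2)
        return cap if k < 0 else min(int(k), cap)

    print(auto_batch_size(n=1000, m=50_000))  # -> 500 (capped)

Without the guard, a negative discriminant raised to the 0.5 power yields a complex number in Python and the later comparison raises a TypeError, which is why the guard is also added to auto_batch in the diff above.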
diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py
index 0f896261..0b46cdac 100644
--- a/docetl/operations/resolve.py
+++ b/docetl/operations/resolve.py
@@ -477,6 +477,7 @@ def auto_batch() -> int:
 
         # Update batch_end to prevent overlapping in the next loop
         batch_end = next_end
+        better_batch = better_batch[:batch_size]
         last_processed = batch_end
 
     with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
diff --git a/docetl/rust/Cargo.lock b/docetl/rust/Cargo.lock
new file mode 100644
index 00000000..f024d7b8
--- /dev/null
+++ b/docetl/rust/Cargo.lock
@@ -0,0 +1,390 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "docetl_resolver" +version = "0.1.0" +dependencies = [ + "ndarray", + "pyo3", + "rayon", +] + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "indoc" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "libc" +version = "0.2.162" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", + "rayon", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", 
+] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "proc-macro2" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" 
+dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/docetl/rust/Cargo.toml b/docetl/rust/Cargo.toml new file mode 100644 index 00000000..7857b2b4 --- /dev/null +++ b/docetl/rust/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "docetl_resolver" +version = "0.1.0" +edition = "2021" + +[lib] +name = "docetl_resolver" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.19", features = ["extension-module"] } +ndarray = { version = "0.15", features = ["rayon"] } +rayon = "1.7" \ No newline at end of file diff --git a/docetl/rust/__init__.py b/docetl/rust/__init__.py new file mode 100644 index 00000000..99049ee1 --- /dev/null +++ b/docetl/rust/__init__.py @@ -0,0 +1 @@ +# This file can be empty, it just marks the directory as a Python package \ No newline at end of file diff --git a/docetl/rust/src/lib.rs b/docetl/rust/src/lib.rs new file mode 100644 index 00000000..b9dc1984 --- /dev/null +++ b/docetl/rust/src/lib.rs @@ -0,0 +1,310 @@ +use pyo3::prelude::*; +use ndarray::{Array2, Array1, Axis}; +use std::collections::{HashSet}; +use pyo3::types::{PyDict, PyList}; +use pyo3::Python; +use pyo3::types::PyModule; + +#[derive(Debug, Clone)] +struct ComparisonPair { + i: usize, + j: usize, + similarity: f64, +} + +#[derive(Debug, Clone)] +struct BlockingRule { + rule_type: String, + key1: String, + key2: String, +} + +#[pyclass] +pub struct FastResolver { + #[pyo3(get, set)] + pub blocking_threshold: f64, + parent: Vec, + size: Vec, + clusters: Vec>, + processed_pairs: HashSet<(usize, usize)>, + blocking_rules: Vec, +} + +#[pymethods] +impl FastResolver { + #[new] + fn new(blocking_threshold: f64) -> Self { + FastResolver { + blocking_threshold, + parent: Vec::new(), + size: Vec::new(), + clusters: Vec::new(), + processed_pairs: HashSet::new(), + blocking_rules: Vec::new(), + } + } + + #[staticmethod] + fn compute_similarity_matrix(embeddings: Vec>) -> Vec> { + let n = embeddings.len(); + let n_features = embeddings[0].len(); + + // Convert to ndarray more efficiently using one allocation + let embedding_data: Vec = embeddings.into_iter().flatten().collect(); + let embedding_matrix = Array2::from_shape_vec((n, n_features), embedding_data) + .expect("Shape mismatch in embedding conversion"); + + // Compute norms using axis operation + let norms: Array1 = embedding_matrix.map_axis(Axis(1), |row| { + (row.dot(&row)).sqrt() + }); + + // Compute similarity matrix directly + let dot_products = embedding_matrix.dot(&embedding_matrix.t()); + let norms_matrix = &norms.view().into_shape((n, 1)).unwrap() + * &norms.view().into_shape((1, n)).unwrap(); + + // Divide element-wise and convert to Vec + let similarity = &dot_products / &norms_matrix; + similarity.outer_iter() + .map(|row| row.to_vec()) + .collect() + } + + fn add_contains_rule(&mut self, key1: String, key2: String) -> PyResult<()> { + self.blocking_rules.push(BlockingRule { + rule_type: "contains".to_string(), + key1, + key2, + }); + Ok(()) + } + + fn add_contained_in_rule(&mut self, key1: String, key2: String) -> PyResult<()> { + self.blocking_rules.push(BlockingRule { + rule_type: "contained_in".to_string(), + key1, + key2, + }); + Ok(()) + } + + fn add_equals_rule(&mut self, key1: String, key2: String) -> PyResult<()> { + self.blocking_rules.push(BlockingRule { + rule_type: "equals".to_string(), + key1, + key2, + }); + Ok(()) + } + + fn check_blocking_rules(&self, item1: &PyDict, item2: &PyDict) -> PyResult { + for rule in 
diff --git a/docetl/rust/src/lib.rs b/docetl/rust/src/lib.rs
new file mode 100644
index 00000000..b9dc1984
--- /dev/null
+++ b/docetl/rust/src/lib.rs
@@ -0,0 +1,310 @@
+use pyo3::prelude::*;
+use ndarray::{Array2, Array1, Axis};
+use std::collections::HashSet;
+use pyo3::types::{PyDict, PyList};
+use pyo3::Python;
+use pyo3::types::PyModule;
+
+#[derive(Debug, Clone)]
+struct ComparisonPair {
+    i: usize,
+    j: usize,
+    similarity: f64,
+}
+
+#[derive(Debug, Clone)]
+struct BlockingRule {
+    rule_type: String,
+    key1: String,
+    key2: String,
+}
+
+#[pyclass]
+pub struct FastResolver {
+    #[pyo3(get, set)]
+    pub blocking_threshold: f64,
+    parent: Vec<usize>,
+    size: Vec<usize>,
+    clusters: Vec<HashSet<usize>>,
+    processed_pairs: HashSet<(usize, usize)>,
+    blocking_rules: Vec<BlockingRule>,
+}
+
+#[pymethods]
+impl FastResolver {
+    #[new]
+    fn new(blocking_threshold: f64) -> Self {
+        FastResolver {
+            blocking_threshold,
+            parent: Vec::new(),
+            size: Vec::new(),
+            clusters: Vec::new(),
+            processed_pairs: HashSet::new(),
+            blocking_rules: Vec::new(),
+        }
+    }
+
+    #[staticmethod]
+    fn compute_similarity_matrix(embeddings: Vec<Vec<f64>>) -> Vec<Vec<f64>> {
+        let n = embeddings.len();
+        let n_features = embeddings[0].len();
+
+        // Convert to ndarray more efficiently using one allocation
+        let embedding_data: Vec<f64> = embeddings.into_iter().flatten().collect();
+        let embedding_matrix = Array2::from_shape_vec((n, n_features), embedding_data)
+            .expect("Shape mismatch in embedding conversion");
+
+        // Compute norms using axis operation
+        let norms: Array1<f64> = embedding_matrix.map_axis(Axis(1), |row| {
+            (row.dot(&row)).sqrt()
+        });
+
+        // Compute similarity matrix directly
+        let dot_products = embedding_matrix.dot(&embedding_matrix.t());
+        let norms_matrix = &norms.view().into_shape((n, 1)).unwrap()
+            * &norms.view().into_shape((1, n)).unwrap();
+
+        // Divide element-wise and convert to Vec
+        let similarity = &dot_products / &norms_matrix;
+        similarity.outer_iter()
+            .map(|row| row.to_vec())
+            .collect()
+    }
+
+    fn add_contains_rule(&mut self, key1: String, key2: String) -> PyResult<()> {
+        self.blocking_rules.push(BlockingRule {
+            rule_type: "contains".to_string(),
+            key1,
+            key2,
+        });
+        Ok(())
+    }
+
+    fn add_contained_in_rule(&mut self, key1: String, key2: String) -> PyResult<()> {
+        self.blocking_rules.push(BlockingRule {
+            rule_type: "contained_in".to_string(),
+            key1,
+            key2,
+        });
+        Ok(())
+    }
+
+    fn add_equals_rule(&mut self, key1: String, key2: String) -> PyResult<()> {
+        self.blocking_rules.push(BlockingRule {
+            rule_type: "equals".to_string(),
+            key1,
+            key2,
+        });
+        Ok(())
+    }
+
+    fn check_blocking_rules(&self, item1: &PyDict, item2: &PyDict) -> PyResult<bool> {
+        for rule in &self.blocking_rules {
+            let val1 = match item1.get_item(&rule.key1) {
+                Some(v) => v.to_string().to_lowercase(),
+                None => continue,
+            };
+            let val2 = match item2.get_item(&rule.key2) {
+                Some(v) => v.to_string().to_lowercase(),
+                None => continue,
+            };
+
+            match rule.rule_type.as_str() {
+                "contains" => {
+                    if val1.contains(&val2) {
+                        return Ok(true);
+                    }
+                }
+                "contained_in" => {
+                    if val2.contains(&val1) {
+                        return Ok(true);
+                    }
+                }
+                "equals" => {
+                    if val1 == val2 {
+                        return Ok(true);
+                    }
+                }
+                _ => continue,
+            }
+        }
+        Ok(false)
+    }
+
+    fn process_items_with_rules<'py>(
+        &mut self,
+        _py: Python<'py>,
+        items: &'py PyList,
+    ) -> PyResult<Vec<(usize, usize)>> {
+        let n_samples = items.len();
+        let mut blocking_pairs = Vec::new();
+
+        // Skip if no blocking rules
+        if self.blocking_rules.is_empty() {
+            return Ok(blocking_pairs);
+        }
+
+        // Check each pair against blocking rules
+        for i in 0..n_samples {
+            for j in (i+1)..n_samples {
+                let item1 = items.get_item(i)?.downcast::<PyDict>()?;
+                let item2 = items.get_item(j)?.downcast::<PyDict>()?;
+
+                if self.check_blocking_rules(item1, item2)? {
+                    // Only add if not already in same cluster and not processed
+                    let root1 = self.find_cluster(i);
+                    let root2 = self.find_cluster(j);
+                    if root1 != root2 && !self.is_processed(i, j) {
+                        blocking_pairs.push((i, j));
+                    }
+                }
+            }
+        }
+
+        Ok(blocking_pairs)
+    }
+
+    fn process_embeddings(
+        &mut self,
+        embeddings: Vec<Vec<f64>>,
+        items: Option<&PyList>,
+    ) -> PyResult<Vec<(usize, usize)>> {
+        if embeddings.is_empty() {
+            return Ok(Vec::new());
+        }
+        if !embeddings.iter().all(|v| v.len() == embeddings[0].len()) {
+            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                "All embeddings must have the same dimension"
+            ));
+        }
+        Python::with_gil(|py| {
+            let sys = PyModule::import(py, "sys")?;
+            let stdout = sys.getattr("stdout")?;
+
+            let n_samples = embeddings.len();
+            stdout.call_method1("write", (format!("Processing embeddings for {} samples...\n", n_samples),))?;
+
+            // Initialize union-find data structures
+            self.parent = (0..n_samples).collect();
+            self.size = vec![1; n_samples];
+            self.clusters = vec![HashSet::new(); n_samples];
+            for i in 0..n_samples {
+                self.clusters[i].insert(i);
+            }
+            self.processed_pairs.clear();
+
+            // Get pairs from embeddings
+            stdout.call_method1("write", ("Computing similarity matrix...\n".to_string(),))?;
+            let mut all_pairs = Vec::new();
+
+            // Add embedding-based pairs
+            let mut pairs = Vec::new();
+            let similarity_matrix = Self::compute_similarity_matrix(embeddings);
+
+            stdout.call_method1("write", ("Finding pairs above threshold...\n".to_string(),))?;
+            for i in 0..n_samples {
+                for j in (i+1)..n_samples {
+                    let similarity = similarity_matrix[i][j];
+                    if similarity >= self.blocking_threshold {
+                        pairs.push(ComparisonPair { i, j, similarity });
+                    }
+                }
+            }
+
+            stdout.call_method1("write",
+                (format!("Found {} pairs above threshold {}\n", pairs.len(), self.blocking_threshold),))?;
+
+            // Sort by similarity descending
+            pairs.sort_unstable_by(|a, b| {
+                b.similarity.partial_cmp(&a.similarity).unwrap()
+            });
+
+            // Convert to (i,j) pairs and add to all_pairs
+            all_pairs.extend(pairs.into_iter().map(|pair| (pair.i, pair.j)));
+
+            // Add blocking rule pairs if items were provided
+            if let Some(items_list) = items {
+                stdout.call_method1("write", ("Applying blocking rules...\n".to_string(),))?;
+                let blocking_pairs = self.process_items_with_rules(py, items_list)?;
+                stdout.call_method1("write",
+                    (format!("Found {} additional pairs from blocking rules\n", blocking_pairs.len()),))?;
+                all_pairs.extend(blocking_pairs);
+            }
+
+            // Filter pairs that are already in the same cluster
+            stdout.call_method1("write", ("Filtering processed pairs...\n".to_string(),))?;
+            let filtered_pairs: Vec<(usize, usize)> = all_pairs.into_iter()
+                .filter(|(i, j)| {
+                    let root1 = self.find_cluster(*i);
+                    let root2 = self.find_cluster(*j);
+                    root1 != root2 && !self.is_processed(*i, *j)
+                })
+                .collect();
+
+            stdout.call_method1("write",
+                (format!("Final number of pairs to process: {}\n", filtered_pairs.len()),))?;
+            stdout.call_method0("flush")?;
+
+            Ok(filtered_pairs)
+        })
+    }
+
+    fn find_cluster(&mut self, mut item: usize) -> usize {
+        while self.parent[item] != item {
+            // Path compression: point to grandparent to flatten the tree
+            self.parent[item] = self.parent[self.parent[item]];
+            item = self.parent[item];
+        }
+        item
+    }
+
+    fn merge_clusters(&mut self, item1: usize, item2: usize) -> PyResult<()> {
+        if item1 >= self.parent.len() || item2 >= self.parent.len() {
+            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                "Invalid cluster index"
+            ));
+        }
+        let mut root1 = self.find_cluster(item1);
+        let mut root2 = self.find_cluster(item2);
+
+        if root1 != root2 {
+            // Union by size - attach smaller tree to root of larger tree
+            if self.size[root1] < self.size[root2] {
+                std::mem::swap(&mut root1, &mut root2);
+            }
+
+            // Merge root2 into root1
+            self.parent[root2] = root1;
+            self.size[root1] += self.size[root2];
+
+            // Merge clusters
+            let items = self.clusters[root2].drain().collect::<Vec<_>>();
+            self.clusters[root1].extend(items);
+        }
+
+        Ok(())
+    }
+
+    fn get_clusters(&self) -> PyResult<Vec<Vec<usize>>> {
+        Ok(self.clusters.iter()
+            .filter(|c| !c.is_empty())
+            .map(|c| c.iter().copied().collect())
+            .collect())
+    }
+
+    fn is_processed(&self, i: usize, j: usize) -> bool {
+        let (min_idx, max_idx) = if i < j { (i, j) } else { (j, i) };
+        self.processed_pairs.contains(&(min_idx, max_idx))
+    }
+
+    fn mark_processed(&mut self, i: usize, j: usize) {
+        let (min_idx, max_idx) = if i < j { (i, j) } else { (j, i) };
+        self.processed_pairs.insert((min_idx, max_idx));
+    }
+}
+
+#[pymodule]
+fn docetl_resolver(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<FastResolver>()?;
+    Ok(())
+}
\ No newline at end of file
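FastResolver's clustering is a textbook union-find (disjoint-set) structure: find_cluster compresses paths by pointing each visited node at its grandparent, and merge_clusters attaches the smaller tree under the larger root. A minimal Python rendering of the same structure (for reference only, not shipped in the patch):

    class UnionFind:
        def __init__(self, n):
            self.parent = list(range(n))
            self.size = [1] * n

        def find(self, x):
            while self.parent[x] != x:
                self.parent[x] = self.parent[self.parent[x]]  # path compression
                x = self.parent[x]
            return x

        def union(self, a, b):
            ra, rb = self.find(a), self.find(b)
            if ra == rb:
                return
            if self.size[ra] < self.size[rb]:  # union by size
                ra, rb = rb, ra
            self.parent[rb] = ra
            self.size[ra] += self.size[rb]

    uf = UnionFind(4)
    uf.union(0, 1)
    uf.union(1, 2)
    print([uf.find(i) for i in range(4)])  # [0, 0, 0, 3]

With path compression and union by size together, a long sequence of merges runs in near-linear time, which is why folding thousands of pairwise match decisions into clusters stays cheap.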
"name_email_resolver", + "type": "fast_resolve", + "blocking_threshold": 0.8, + "blocking_keys": ["name", "email"], + "comparison_prompt": """Compare these two entries and determine if they refer to the same person: + Person 1: {{ input1.name }} {{ input1.email }} + Person 2: {{ input2.name }} {{ input2.email }} + Return true if they match, false otherwise.""", + "resolution_prompt": "Given these similar entries, determine the canonical form: {{ inputs }}", + "output": { + "schema": { + "name": "string", + "email": "string" + } + }, + "embedding_model": "text-embedding-3-small", + "comparison_model": "azure/gpt-4o-mini", + "resolution_model": "azure/gpt-4o-mini" + } + + +def generate_large_dataset(num_base_records=100): + """Generate a very large dataset with intentional duplicates and transitive relationships. + + Example of transitivity: + - John Doe <-> Johnny Doe <-> J. Doe (all same email) + - Multiple email variations for same person + - Name variations that chain together + """ + + # Base data to create variations from + first_names = ['John', 'Michael', 'William', 'James', 'David', 'Robert', 'Thomas', 'Christopher'] + last_names = ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Garcia', 'Miller', 'Davis'] + domains = ['gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com', 'company.com'] + + data = [] + + # Create base records with intentional relationships + for _ in range(num_base_records): + first = random.choice(first_names) + last = random.choice(last_names) + domain = random.choice(domains) + + # Create base email variations for this person + email_variations = [ + f"{first.lower()}.{last.lower()}@{domain}", + f"{first.lower()[0]}{last.lower()}@{domain}", + f"{first.lower()}{last.lower()[0]}@{domain}", + f"{first.lower()}_{last.lower()}@{domain}" + ] + + # Create name variations that chain together + name_variations = [ + f"{first} {last}", # Standard + f"{first}y {last}", # Diminutive + f"{first[0]}. {last}", # Initial + f"{first} {last[0]}.", # Last initial + f"{first[0]}. {last[0]}.", # Both initials + ] + + # Add middle initials to some variations + middle_initials = random.sample(string.ascii_uppercase, 2) + name_variations.extend([ + f"{first} {mi}. 
+
+
+def generate_large_dataset(num_base_records=100):
+    """Generate a very large dataset with intentional duplicates and transitive relationships.
+
+    Example of transitivity:
+    - John Doe <-> Johnny Doe <-> J. Doe (all same email)
+    - Multiple email variations for same person
+    - Name variations that chain together
+    """
+
+    # Base data to create variations from
+    first_names = ['John', 'Michael', 'William', 'James', 'David', 'Robert', 'Thomas', 'Christopher']
+    last_names = ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Garcia', 'Miller', 'Davis']
+    domains = ['gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com', 'company.com']
+
+    data = []
+
+    # Create base records with intentional relationships
+    for _ in range(num_base_records):
+        first = random.choice(first_names)
+        last = random.choice(last_names)
+        domain = random.choice(domains)
+
+        # Create base email variations for this person
+        email_variations = [
+            f"{first.lower()}.{last.lower()}@{domain}",
+            f"{first.lower()[0]}{last.lower()}@{domain}",
+            f"{first.lower()}{last.lower()[0]}@{domain}",
+            f"{first.lower()}_{last.lower()}@{domain}"
+        ]
+
+        # Create name variations that chain together
+        name_variations = [
+            f"{first} {last}",          # Standard
+            f"{first}y {last}",         # Diminutive
+            f"{first[0]}. {last}",      # Initial
+            f"{first} {last[0]}.",      # Last initial
+            f"{first[0]}. {last[0]}.",  # Both initials
+        ]
+
+        # Add middle initials to some variations
+        middle_initials = random.sample(string.ascii_uppercase, 2)
+        name_variations.extend([
+            f"{first} {mi}. {last}" for mi in middle_initials
+        ])
+
+        # Create multiple records with combinations of name/email variations
+        # This ensures transitive relationships
+        for name in name_variations:
+            # Use same email for some variations to create strong links
+            primary_email = random.choice(email_variations)
+            data.append({"name": name, "email": primary_email})
+
+            # Add some variations with different emails
+            if random.random() < 0.3:
+                data.append({"name": name, "email": random.choice(email_variations)})
+
+            # Add typo variations
+            if random.random() < 0.2:
+                typo_name = name.replace('i', 'y') if 'i' in name else name + 'n'
+                data.append({"name": typo_name, "email": primary_email})
+
+        # Add some completely different email domains for same person
+        alt_domain = random.choice([d for d in domains if d != domain])
+        alt_email = f"{first.lower()}.{last.lower()}@{alt_domain}"
+        data.append({"name": random.choice(name_variations), "email": alt_email})
+
+    # Shuffle the dataset
+    random.shuffle(data)
+
+    # Print some statistics about the dataset
+    print("\nGenerated Dataset Statistics:")
+    print(f"Total records: {len(data)}")
+    print(f"Unique names: {len(set(r['name'] for r in data))}")
+    print(f"Unique emails: {len(set(r['email'] for r in data))}")
+    print(f"Average variations per base record: {len(data) / num_base_records:.1f}")
+
+    return data
+
+
+@pytest.fixture
+def fast_resolve_sample_data():
+    # Set random seed for reproducibility
+    random.seed(42)
+    return generate_large_dataset()
+
+
+def dont_do_test_fast_resolve_operation(
+    fast_resolve_config, default_model, fast_resolve_sample_data, api_wrapper
+):
+
+    distinct_names = set(result["name"] for result in fast_resolve_sample_data)
+    distinct_emails = set(result["email"] for result in fast_resolve_sample_data)
+    print(f"Distinct names in input: {len(distinct_names)}")
+    print(f"Distinct emails in input: {len(distinct_emails)}")
+
+    operation = FastResolveOperation(
+        api_wrapper, fast_resolve_config, default_model, 256
+    )
+    results, cost = operation.execute(fast_resolve_sample_data)
+
+    # Calculate and print some statistics
+    input_count = len(fast_resolve_sample_data)
+    output_count = len(results)
+    distinct_output_names = set(result["name"] for result in results)
+    distinct_output_emails = set(result["email"] for result in results)
+
+    print("\nTest Statistics:")
+    print(f"Input records: {input_count}")
+    print(f"Output records: {output_count}")
+    print(f"Distinct names in output: {len(distinct_output_names)}")
+    print(f"Distinct emails in output: {len(distinct_output_emails)}")
+    print(f"Reduction ratio: {(input_count - output_count) / input_count:.2%}")
+    print(f"Total cost: {cost}")
+
+    # Assertions
+    assert len(distinct_names) < len(fast_resolve_sample_data)
+    assert output_count == input_count
+    assert cost > 0
+
+
+def test_fast_resolve_operation_empty_input(
+    fast_resolve_config, default_model, max_threads, api_wrapper
+):
+    operation = FastResolveOperation(
+        api_wrapper, fast_resolve_config, default_model, max_threads
+    )
+    results, cost = operation.execute([])
+
+    assert len(results) == 0
+    assert cost == 0
+
+
+@pytest.fixture
+def resolve_config():
+    return {
+        "name": "name_email_resolver",
+        "type": "resolve",
+        "blocking_keys": ["name", "email"],
+        "blocking_threshold": 0.8,
+        "comparison_prompt": "Compare these two entries and determine if they refer to the same person: Person 1: {{ input1 }} Person 2: {{ input2 }} Return true if they match, false otherwise.",
+        "resolution_prompt": "Given these similar entries, determine the canonical form: {{ inputs }}",
+        "output": {"schema": {"name": "string", "email": "string"}},
+        "embedding_model": "text-embedding-3-small",
+        "comparison_model": "gpt-4o-mini",
+        "resolution_model": "gpt-4o-mini"
+    }
+
+
+def test_compare_resolve_performance(
+    fast_resolve_config, resolve_config, default_model, api_wrapper
+):
+    """Compare performance between FastResolve and regular Resolve operations."""
+
+    # Generate a large dataset specifically for this test
+    large_dataset = generate_large_dataset()
+
+    # Use a smaller subset for the regular resolve to keep test duration reasonable
+    sample_size = min(len(large_dataset) // 4, 1000)
+    sample_data = random.sample(large_dataset, sample_size)
+
+    print(f"\nTesting with {len(large_dataset)} records for FastResolve")
+    print(f"Testing with {len(sample_data)} records for regular Resolve")
+
+    # Test FastResolve with full dataset
+    start_time = time.time()
+    fast_operation = FastResolveOperation(
+        api_wrapper, fast_resolve_config, default_model, 256
+    )
+    fast_results, fast_cost = fast_operation.execute(large_dataset)
+    fast_time = time.time() - start_time
+
+    # Test regular Resolve with sample
+    start_time = time.time()
+    regular_operation = ResolveOperation(
+        api_wrapper, resolve_config, default_model, 256
+    )
+    regular_results, regular_cost = regular_operation.execute(sample_data)
+    regular_time = time.time() - start_time
+
+    # Scale up regular metrics for fair comparison
+    scale_factor = len(large_dataset) / len(sample_data)
+    regular_time_scaled = regular_time * scale_factor
+    regular_cost_scaled = regular_cost * scale_factor
+
+    # Print performance comparison
+    print("\nPerformance Comparison (scaled to same dataset size):")
+    print(f"FastResolve Time: {fast_time:.2f} seconds")
+    print(f"Regular Resolve Time (scaled): {regular_time_scaled:.2f} seconds")
+    print(f"FastResolve Cost: ${fast_cost:.4f}")
+    print(f"Regular Resolve Cost (scaled): ${regular_cost_scaled:.4f}")
+    print(f"Speed Improvement: {(regular_time_scaled - fast_time) / regular_time_scaled:.1%}")
+    print(f"Cost Savings: {(regular_cost_scaled - fast_cost) / regular_cost_scaled:.1%}")
+
+    # Assertions
+    assert fast_time < regular_time_scaled, "FastResolve should be faster than regular Resolve"
+    assert fast_cost < regular_cost_scaled, "FastResolve should be more cost-effective"
\ No newline at end of file
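A note before the second patch: patch 1's condition parser dispatches on the substring test `"in" in condition`, which also matches the `in` inside `input1`/`input2`, so equality conditions are mis-routed into the containment branch. The patch below reorders the checks and splits on ` in ` with surrounding spaces. A self-contained sketch of the corrected dispatch (the helper names here are hypothetical, not from the patch):

    def parse_condition(condition):
        # '==' is tested first so the 'in' inside 'input1'/'input2' cannot misfire.
        def key_of(expr):
            return expr.split("[")[1].split("]")[0].strip("\"'")

        if "==" in condition:
            lhs, rhs = condition.split("==")
            return ("equals", key_of(lhs), key_of(rhs))
        if " in " in condition:
            lhs, rhs = condition.split(" in ")
            kind = "contains" if lhs.strip().startswith("input1") else "contained_in"
            return (kind, key_of(lhs), key_of(rhs))
        return None

    print(parse_condition("input1['email'].lower() == input2['email'].lower()"))
    # ('equals', 'email', 'email')
    print(parse_condition("input1['name'].lower() in input2['name'].lower()"))
    # ('contains', 'name', 'name')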
From 1d40ac5ee2479a23685e8d420651ca509594e971 Mon Sep 17 00:00:00 2001
From: Shreya Shankar
Date: Sat, 9 Nov 2024 18:58:21 -0800
Subject: [PATCH 2/3] feat: performant resolve

---
 docetl/operations/fast_resolve.py |  51 ++++++++++---
 docetl/rust/Cargo.lock            |  99 ++++++++++++++++++++++++-
 docetl/rust/Cargo.toml            |   3 +-
 docetl/rust/src/lib.rs            | 115 ++++++++++++++++++++----------
 tests/test_fast_resolve.py        |  74 +++++++++----------
 5 files changed, 250 insertions(+), 92 deletions(-)

diff --git a/docetl/operations/fast_resolve.py b/docetl/operations/fast_resolve.py
index ce478012..e76dc70c 100644
--- a/docetl/operations/fast_resolve.py
+++ b/docetl/operations/fast_resolve.py
@@ -9,6 +9,8 @@
 import jinja2
 from docetl.operations.utils import RichLoopBar, rich_as_completed
 
+from rich.prompt import Confirm
+
 class FastResolveOperation(BaseOperation):
     class schema(BaseOperation.schema):
         type: str = "fast_resolve"
@@ -75,7 +77,9 @@ def __init__(
     ):
         super().__init__(runner, config, default_model, max_threads, console, status, is_build, **kwargs)
         self.resolver = FastResolver(
-            blocking_threshold=config.get("blocking_threshold", 0.8)
+            blocking_threshold=config.get("blocking_threshold", None),
+            debug=config.get("debug", False),
+            limit_comparisons=config.get("limit_comparisons", None),
         )
 
     def batch_embeddings(self, items: List[Dict], batch_size: int = 1000) -> Tuple[List[List[float]], float]:
@@ -218,6 +222,23 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
         """Execute the fast resolve operation."""
         if not input_data:
             return [], 0
+
+        blocking_threshold = self.config.get("blocking_threshold")
+        blocking_conditions = self.config.get("blocking_conditions", [])
+
+        if self.status:
+            self.status.stop()
+
+        if not blocking_threshold and not blocking_conditions:
+            # Prompt the user for confirmation
+            if not Confirm.ask(
+                f"[yellow]Warning: No blocking keys or conditions specified. "
+                f"This may result in a large number of comparisons. "
+                f"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
+                f"Do you want to continue without blocking?[/yellow]",
+                console=self.console,
+            ):
+                raise ValueError("Operation cancelled by user.")
 
         self.input_data = input_data
         total_cost = 0
@@ -226,22 +247,31 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
         blocking_conditions = self.config.get("blocking_conditions", [])
         for condition in blocking_conditions:
             # Parse the condition string to extract keys and operation
-            if "in" in condition:
-                parts = condition.split("in")
+            if "==" in condition:
+                parts = condition.split("==")
+                if parts[0].strip().endswith(".lower()") and parts[1].strip().endswith(".lower()"):
+                    key1 = parts[0].split("[")[1].split("]")[0].strip('"\'')
+                    key2 = parts[1].split("[")[1].split("]")[0].strip('"\'')
+                    self.resolver.add_equals_rule(key1, key2)
+                    self.console.log(f"Added equals rule: {key1} equals {key2}")
+                else:
+                    self.console.log(f"Skipped '==' condition - not using .lower(): {condition}")
+            elif " in " in condition:
+                parts = condition.split(" in ")
                 if parts[0].strip().endswith(".lower()") and parts[1].strip().endswith(".lower()"):
                     key1 = parts[0].split("[")[1].split("]")[0].strip('"\'')
                     key2 = parts[1].split("[")[1].split("]")[0].strip('"\'')
 
                     if parts[0].strip().startswith("input1"):
                         self.resolver.add_contains_rule(key1, key2)
+                        self.console.log(f"Added contains rule: {key1} contains {key2}")
                     else:
                         self.resolver.add_contained_in_rule(key1, key2)
-            elif "==" in condition:
-                parts = condition.split("==")
-                if parts[0].strip().endswith(".lower()") and parts[1].strip().endswith(".lower()"):
-                    key1 = parts[0].split("[")[1].split("]")[0].strip('"\'')
-                    key2 = parts[1].split("[")[1].split("]")[0].strip('"\'')
-                    self.resolver.add_equals_rule(key1, key2)
+                        self.console.log(f"Added contained_in rule: {key1} contained in {key2}")
+                    else:
+                        self.console.log(f"Skipped 'in' condition - not using .lower(): {condition}")
+            else:
+                self.console.log(f"Skipped condition - no recognized operator: {condition}")
 
         # Get embeddings with configurable batch size
         embedding_batch_size = self.config.get("embedding_batch_size", 1000)
@@ -341,5 +371,8 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
                 cluster_results, cost = future.result()
                 results.extend(cluster_results)
                 total_cost += cost
+
+        if self.status:
+            self.status.start()
 
         return results, total_cost
\ No newline at end of file
diff --git a/docetl/rust/Cargo.lock b/docetl/rust/Cargo.lock
index f024d7b8..5b77a8c3 100644
--- a/docetl/rust/Cargo.lock
+++ b/docetl/rust/Cargo.lock
@@ -14,6 +14,12 @@
 version = "2.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = 
"b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "cfg-if" version = "1.0.0" @@ -51,6 +57,7 @@ version = "0.1.0" dependencies = [ "ndarray", "pyo3", + "rand", "rayon", ] @@ -60,6 +67,17 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "indoc" version = "1.0.9" @@ -171,6 +189,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + [[package]] name = "proc-macro2" version = "1.0.89" @@ -226,7 +253,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -237,7 +264,7 @@ checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -249,6 +276,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "rawpointer" version = "0.2.1" @@ -307,6 +364,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "target-lexicon" version = "0.12.16" @@ -325,6 +393,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "windows-targets" version = "0.52.6" @@ -388,3 +462,24 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
+dependencies = [
+ "byteorder",
+ "zerocopy-derive",
+]
+
+[[package]]
+name = "zerocopy-derive"
+version = "0.7.35"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.87",
+]
diff --git a/docetl/rust/Cargo.toml b/docetl/rust/Cargo.toml
index 7857b2b4..cb2388b6 100644
--- a/docetl/rust/Cargo.toml
+++ b/docetl/rust/Cargo.toml
@@ -10,4 +10,5 @@ crate-type = ["cdylib"]
 [dependencies]
 pyo3 = { version = "0.19", features = ["extension-module"] }
 ndarray = { version = "0.15", features = ["rayon"] }
-rayon = "1.7"
\ No newline at end of file
+rayon = "1.7"
+rand = "0.8"
\ No newline at end of file
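With the constructor change in the lib.rs diff below, all three resolver knobs become optional on the Python side. A hypothetical construction, for orientation (the values are illustrative examples, not defaults from the patch):

    from docetl_resolver import FastResolver

    resolver = FastResolver(
        blocking_threshold=0.8,   # None would disable the similarity cutoff entirely
        debug=True,               # enables the Rust-side progress prints
        limit_comparisons=1000,   # hard cap on candidate pairs
    )
    resolver.add_equals_rule("email", "email")  # exact-match blocking on email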
diff --git a/docetl/rust/src/lib.rs b/docetl/rust/src/lib.rs
index b9dc1984..c0d8dabc 100644
--- a/docetl/rust/src/lib.rs
+++ b/docetl/rust/src/lib.rs
@@ -22,7 +22,11 @@ struct BlockingRule {
 #[pyclass]
 pub struct FastResolver {
     #[pyo3(get, set)]
-    pub blocking_threshold: f64,
+    pub blocking_threshold: Option<f64>,
+    #[pyo3(get, set)]
+    pub debug: bool,
+    #[pyo3(get, set)]
+    pub limit_comparisons: Option<usize>,
     parent: Vec<usize>,
     size: Vec<usize>,
     clusters: Vec<HashSet<usize>>,
@@ -33,9 +37,11 @@ pub struct FastResolver {
 #[pymethods]
 impl FastResolver {
     #[new]
-    fn new(blocking_threshold: f64) -> Self {
+    fn new(blocking_threshold: Option<f64>, debug: Option<bool>, limit_comparisons: Option<usize>) -> Self {
         FastResolver {
             blocking_threshold,
+            debug: debug.unwrap_or(false),
+            limit_comparisons,
             parent: Vec::new(),
             size: Vec::new(),
             clusters: Vec::new(),
@@ -102,11 +108,15 @@ impl FastResolver {
         for rule in &self.blocking_rules {
             let val1 = match item1.get_item(&rule.key1) {
                 Some(v) => v.to_string().to_lowercase(),
-                None => continue,
+                None => {
+                    continue;
+                },
             };
             let val2 = match item2.get_item(&rule.key2) {
                 Some(v) => v.to_string().to_lowercase(),
-                None => continue,
+                None => {
+                    continue;
+                },
             };
 
             match rule.rule_type.as_str() {
@@ -144,6 +154,20 @@ impl FastResolver {
             return Ok(blocking_pairs);
         }
 
+        // Print rules once before processing
+        if self.debug {
+            println!("\nChecking blocking rules:");
+            for rule in &self.blocking_rules {
+                match rule.rule_type.as_str() {
+                    "contains" => println!("- CONTAINS rule: input1 {} contains input2 {}", rule.key1, rule.key2),
+                    "contained_in" => println!("- CONTAINED_IN rule: input1 {} is contained in input2 {}", rule.key1, rule.key2),
+                    "equals" => println!("- EQUALS rule: input1 {} equals input2 {}", rule.key1, rule.key2),
+                    _ => println!("- Unknown rule type: {}", rule.rule_type),
+                }
+            }
+            println!(); // Empty line for readability
+        }
+
         // Check each pair against blocking rules
         for i in 0..n_samples {
             for j in (i+1)..n_samples {
@@ -151,7 +175,6 @@ impl FastResolver {
                 let item2 = items.get_item(j)?.downcast::<PyDict>()?;
 
                 if self.check_blocking_rules(item1, item2)? {
-                    // Only add if not already in same cluster and not processed
                     let root1 = self.find_cluster(i);
                     let root2 = self.find_cluster(j);
                     if root1 != root2 && !self.is_processed(i, j) {
@@ -178,62 +201,69 @@ impl FastResolver {
             ));
         }
         Python::with_gil(|py| {
-            let sys = PyModule::import(py, "sys")?;
-            let stdout = sys.getattr("stdout")?;
-
             let n_samples = embeddings.len();
-            stdout.call_method1("write", (format!("Processing embeddings for {} samples...\n", n_samples),))?;
 
-            // Initialize union-find data structures
+            if self.debug {
+                println!("Processing embeddings for {} samples...", n_samples);
+            }
+
+            // Initialize only parent and size vectors
             self.parent = (0..n_samples).collect();
             self.size = vec![1; n_samples];
-            self.clusters = vec![HashSet::new(); n_samples];
-            for i in 0..n_samples {
-                self.clusters[i].insert(i);
-            }
             self.processed_pairs.clear();
 
-            // Get pairs from embeddings
-            stdout.call_method1("write", ("Computing similarity matrix...\n".to_string(),))?;
             let mut all_pairs = Vec::new();
+            let mut similarity_pairs = Vec::new();
+
+            if self.debug {
+                println!("Computing similarity matrix...");
+            }
 
-            // Add embedding-based pairs
-            let mut pairs = Vec::new();
             let similarity_matrix = Self::compute_similarity_matrix(embeddings);
 
-            stdout.call_method1("write", ("Finding pairs above threshold...\n".to_string(),))?;
+            // Store all pairs with their similarities
             for i in 0..n_samples {
                 for j in (i+1)..n_samples {
                     let similarity = similarity_matrix[i][j];
-                    if similarity >= self.blocking_threshold {
-                        pairs.push(ComparisonPair { i, j, similarity });
+                    if self.blocking_threshold.map_or(true, |t| similarity >= t) {
+                        similarity_pairs.push(ComparisonPair { i, j, similarity });
                     }
                 }
             }
 
-            stdout.call_method1("write",
-                (format!("Found {} pairs above threshold {}\n", pairs.len(), self.blocking_threshold),))?;
-
-            // Sort by similarity descending
-            pairs.sort_unstable_by(|a, b| {
+            similarity_pairs.sort_unstable_by(|a, b| {
                 b.similarity.partial_cmp(&a.similarity).unwrap()
             });
-
-            // Convert to (i,j) pairs and add to all_pairs
-            all_pairs.extend(pairs.into_iter().map(|pair| (pair.i, pair.j)));
 
             // Add blocking rule pairs if items were provided
             if let Some(items_list) = items {
-                stdout.call_method1("write", ("Applying blocking rules...\n".to_string(),))?;
+                if self.debug {
+                    println!("Applying blocking rules...");
+                }
+
                 let blocking_pairs = self.process_items_with_rules(py, items_list)?;
-                stdout.call_method1("write",
-                    (format!("Found {} additional pairs from blocking rules\n", blocking_pairs.len()),))?;
+
+                if self.debug {
+                    println!("Found {} pairs from blocking rules", blocking_pairs.len());
+                }
+
                 all_pairs.extend(blocking_pairs);
             }
 
-            // Filter pairs that are already in the same cluster
-            stdout.call_method1("write", ("Filtering processed pairs...\n".to_string(),))?;
-            let filtered_pairs: Vec<(usize, usize)> = all_pairs.into_iter()
+            // Add similarity pairs after blocking pairs
+            all_pairs.extend(similarity_pairs.into_iter().map(|pair| (pair.i, pair.j)));
+
+            // Initialize clusters only after all pairs are collected
+            self.clusters = vec![HashSet::new(); n_samples];
+            for i in 0..n_samples {
+                self.clusters[i].insert(i);
+            }
+
+            if self.debug {
+                println!("Filtering processed pairs...");
+            }
+
+            let mut filtered_pairs: Vec<(usize, usize)> = all_pairs.into_iter()
                 .filter(|(i, j)| {
                     let root1 = self.find_cluster(*i);
                     let root2 = self.find_cluster(*j);
diff --git a/tests/test_fast_resolve.py b/tests/test_fast_resolve.py
index 0dd5958f..1fd0fbad 100644
--- a/tests/test_fast_resolve.py
+++ b/tests/test_fast_resolve.py
@@ -12,12 +12,19 @@ def fast_resolve_config():
         "name": "name_email_resolver",
         "type": "fast_resolve",
         "blocking_threshold": 0.8,
+        "debug": True,
         "blocking_keys": ["name", "email"],
+        "blocking_conditions": [
+            "input1['email'].lower() == input2['email'].lower()",  # Exact email match
+            "input1['name'].lower() in input2['name'].lower()",  # Name containment
+            "input2['name'].lower() in input1['name'].lower()"  # Reverse name containment
+        ],
         "comparison_prompt": """Compare these two entries and determine if they refer to the same person:
         Person 1: {{ input1.name }} {{ input1.email }}
         Person 2: {{ input2.name }} {{ input2.email }}
         Return true if they match, false otherwise.""",
-        "resolution_prompt": "Given these similar entries, determine the canonical form: {{ inputs }}",
+        "resolution_prompt": """Given these similar entries, determine the canonical form.
+        Choose the most complete name and the most professional email address: {{ inputs }}""",
         "output": {
             "schema": {
                 "name": "string",
@@ -26,7 +33,9 @@
         },
         "embedding_model": "text-embedding-3-small",
         "comparison_model": "azure/gpt-4o-mini",
-        "resolution_model": "azure/gpt-4o-mini"
+        "resolution_model": "azure/gpt-4o-mini",
+        "embedding_batch_size": 1000,
+        "limit_comparisons": 1000
     }
 
 
@@ -162,38 +171,17 @@ def test_fast_resolve_operation_empty_input(
     assert cost == 0
 
 
-@pytest.fixture
-def resolve_config():
-    return {
-        "name": "name_email_resolver",
-        "type": "resolve",
-        "blocking_keys": ["name", "email"],
-        "blocking_threshold": 0.8,
-        "comparison_prompt": "Compare these two entries and determine if they refer to the same person: Person 1: {{ input1 }} Person 2: {{ input2 }} Return true if they match, false otherwise.",
-        "resolution_prompt": "Given these similar entries, determine the canonical form: {{ inputs }}",
-        "output": {"schema": {"name": "string", "email": "string"}},
-        "embedding_model": "text-embedding-3-small",
-        "comparison_model": "gpt-4o-mini",
-        "resolution_model": "gpt-4o-mini"
-    }
-
 def test_compare_resolve_performance(
-    fast_resolve_config, resolve_config, default_model, api_wrapper
+    fast_resolve_config, default_model, api_wrapper
 ):
     """Compare performance between FastResolve and regular Resolve operations."""
-    # Generate a large dataset specifically for this test
+    # Generate a smaller dataset for testing
     large_dataset = generate_large_dataset()
+    print(f"\nTesting with {len(large_dataset)} records")
 
-    # Use a smaller subset for the regular resolve to keep test duration reasonable
-    sample_size = min(len(large_dataset) // 4, 1000)
-    sample_data = random.sample(large_dataset, sample_size)
-
-    print(f"\nTesting with {len(large_dataset)} records for FastResolve")
-    print(f"Testing with {len(sample_data)} records for regular Resolve")
-
-    # Test FastResolve with full dataset
+    # Test FastResolve with blocking rules
     start_time = time.time()
     fast_operation = FastResolveOperation(
         api_wrapper, fast_resolve_config, default_model, 256
@@ -204,25 +192,27 @@ def test_compare_resolve_performance(
     # Test regular Resolve with sample
     start_time = time.time()
     regular_operation = ResolveOperation(
-        api_wrapper, resolve_config, default_model, 256
+        api_wrapper, fast_resolve_config, default_model, 256
     )
-    regular_results, regular_cost = regular_operation.execute(sample_data)
+    regular_results, regular_cost = regular_operation.execute(large_dataset)
     regular_time = time.time() - start_time
 
-    # Scale up regular metrics for fair comparison
-    scale_factor = len(large_dataset) / len(sample_data)
-    regular_time_scaled = regular_time * scale_factor
-    regular_cost_scaled = regular_cost * scale_factor
-
-    # Print performance comparison
-    print("\nPerformance Comparison (scaled to same dataset size):")
+    # Print detailed performance metrics
+    print("\nPerformance Comparison:")
     print(f"FastResolve Time: {fast_time:.2f} seconds")
-    print(f"Regular Resolve Time (scaled): {regular_time_scaled:.2f} seconds")
+    print(f"Regular Resolve Time: {regular_time:.2f} seconds")
     print(f"FastResolve Cost: ${fast_cost:.4f}")
-    print(f"Regular Resolve Cost (scaled): ${regular_cost_scaled:.4f}")
-    print(f"Speed Improvement: {(regular_time_scaled - fast_time) / regular_time_scaled:.1%}")
-    print(f"Cost Savings: {(regular_cost_scaled - fast_cost) / regular_cost_scaled:.1%}")
+    print(f"Regular Resolve Cost: ${regular_cost:.4f}")
+    print(f"Speed Improvement: {(regular_time - fast_time) / regular_time:.1%}")
+    print(f"Cost Savings: {(regular_cost - fast_cost) / regular_cost:.1%}")
+
+    # Additional metrics
+    print("\nResolution Quality Metrics:")
+    print(f"FastResolve output records: {len(fast_results)}")
+    print(f"Distinct names in output: {len(set(r['name'] for r in fast_results))}")
+    print(f"Distinct emails in output: {len(set(r['email'] for r in fast_results))}")
+    print(f"Reduction ratio: {(len(large_dataset) - len(fast_results)) / len(large_dataset):.2%}")
 
     # Assertions
-    assert fast_time < regular_time_scaled, "FastResolve should be faster than regular Resolve"
-    assert fast_cost < regular_cost_scaled, "FastResolve should be more cost-effective"
\ No newline at end of file
+    assert fast_time < regular_time, "FastResolve should be faster than regular Resolve"
+    assert len(fast_results) <= len(large_dataset), "Output should not be larger than input"
\ No newline at end of file
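The blocking_conditions added to the fixture are Python expressions over input1 and input2; pairs satisfying any one of them are queued for LLM comparison ahead of the embedding-similarity pairs (the Rust code above extends all_pairs with blocking pairs first). A sketch of that evaluation, with a hypothetical should_compare helper, not the library's API; eval over fixed strings is tolerable in a test context like this but unsafe on untrusted input:

    conditions = [
        "input1['email'].lower() == input2['email'].lower()",
        "input1['name'].lower() in input2['name'].lower()",
        "input2['name'].lower() in input1['name'].lower()",
    ]

    def should_compare(input1, input2):
        # True as soon as any blocking condition fires for the pair.
        env = {"input1": input1, "input2": input2}
        return any(eval(cond, {}, env) for cond in conditions)

    a = {"name": "Jane Doe", "email": "jane.doe@example.com"}
    b = {"name": "Jane Doe Smith", "email": "jdoe@example.com"}
    print(should_compare(a, b))  # True: name containment fires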
From 37522bb6520b80dff7b93bf746b73c176c1eb944 Mon Sep 17 00:00:00 2001
From: Shreya Shankar
Date: Sat, 9 Nov 2024 19:02:27 -0800
Subject: [PATCH 3/3] fix: add pip install maturin to ci and cd

---
 .github/workflows/ci.yml   |  3 ++
 .github/workflows/docs.yml | 72 ++++++++++++++++++++------------------
 2 files changed, 40 insertions(+), 35 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3ebcbe8d..13d19c6b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -28,6 +28,9 @@ jobs:
       - name: Install Poetry
         uses: snok/install-poetry@v1
 
+      - name: Install maturin
+        run: pip install maturin
+
       - name: Copy environment file
         run: cp .env.sample .env
 
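Both workflows now install maturin explicitly before running make install, since the project install builds the Rust extension in-place. A local reproduction of the same ordering, reduced to a sketch:

    import subprocess

    # Mirror the CI ordering: maturin must be on PATH before the project
    # install, which compiles the extension in docetl/rust via maturin.
    subprocess.run(["pip", "install", "maturin"], check=True)
    subprocess.run(["make", "install"], check=True)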
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 176e47b8..b3ce3b38 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -1,35 +1,37 @@
-  name: docs
-  on:
-    push:
-      branches:
-        - master
-        - main
-  permissions:
-    contents: write
-  jobs:
-    deploy:
-      runs-on: ubuntu-latest
-      steps:
-        - uses: actions/checkout@v4
-        - name: Configure Git Credentials
-          run: |
-            git config user.name github-actions[bot]
-            git config user.email 41898282+github-actions[bot]@users.noreply.github.com
-        - uses: actions/setup-python@v5
-          with:
-            python-version: 3.x
-        - name: Install Poetry
-          uses: snok/install-poetry@v1
-        - name: Copy environment file
-          run: cp .env.sample .env
-        - name: Install dependencies
-          run: make install
-        - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
-        - uses: actions/cache@v4
-          with:
-            key: mkdocs-material-${{ env.cache_id }}
-            path: .cache
-            restore-keys: |
-              mkdocs-material-
-        - run: poetry run mkdocs build
-        - run: poetry run mkdocs gh-deploy --force
+name: docs
+on:
+  push:
+    branches:
+      - master
+      - main
+permissions:
+  contents: write
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Configure Git Credentials
+        run: |
+          git config user.name github-actions[bot]
+          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
+      - uses: actions/setup-python@v5
+        with:
+          python-version: 3.x
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+      - name: Install maturin
+        run: pip install maturin
+      - name: Copy environment file
+        run: cp .env.sample .env
+      - name: Install dependencies
+        run: make install
+      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+      - uses: actions/cache@v4
+        with:
+          key: mkdocs-material-${{ env.cache_id }}
+          path: .cache
+          restore-keys: |
+            mkdocs-material-
+      - run: poetry run mkdocs build
+      - run: poetry run mkdocs gh-deploy --force
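With maturin available in both workflows, the built extension can be smoke-tested right after installation. A sketch of such a check, using only the attributes the lib.rs diff above exposes through #[pyo3(get, set)]:

    # Run after `make install`: the extension should import, and the new
    # optional constructor arguments should round-trip through the getters.
    from docetl_resolver import FastResolver

    r = FastResolver(None, True, 500)
    assert r.blocking_threshold is None
    assert r.debug is True
    assert r.limit_comparisons == 500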