#!/usr/bin/env python3
# Script Version: 0.4
# Description: Semantic search over a local embeddings.json, with content preview and optional file copy.

import json
import os
import shutil
from itertools import islice

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Config
EMBEDDING_FILE = "embeddings.json"
CONTENT_DIR = "content"
RESULTS_DIR = "results"
MODEL_NAME = "all-mpnet-base-v2"
PREVIEW_LINES = 5  # Number of lines to preview from the matching .txt files
TOP_K = 3  # Number of best matches to display


def load_model(name=MODEL_NAME):
    """Load the sentence-embedding model *name*, moving it to CUDA when available."""
    model = SentenceTransformer(name)
    if torch.cuda.is_available():
        model = model.to("cuda")
        print("[INFO] Running on GPU")
    return model


def load_embeddings(path=EMBEDDING_FILE):
    """Read the stored embeddings from the JSON file at *path*.

    Returns:
        The decoded JSON object. The script iterates it as a mapping of
        file name -> embedding vector.

    Raises:
        OSError: the file cannot be opened.
        ValueError: the file is not valid JSON, or its top level is not an
            object (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object mapping file names to embedding vectors")
    return data


def rank_matches(query_embedding, stored_embeddings, top_k=TOP_K):
    """Rank stored embeddings by cosine similarity to *query_embedding*.

    All similarities are computed in one vectorized NumPy pass instead of one
    library call per stored vector.

    Args:
        query_embedding: 1-D sequence/array of numbers.
        stored_embeddings: mapping of file name -> vector of the same length.
        top_k: maximum number of results to return (``None`` returns all).

    Returns:
        A list of ``(file name, score)`` tuples, best score first. Entries
        with equal scores keep the order they have in *stored_embeddings*.

    Raises:
        ValueError: the stored vectors are not numeric, differ in length, or
            do not match the length of the query vector.
    """
    if not stored_embeddings:
        return []

    names = list(stored_embeddings)
    try:
        matrix = np.asarray([stored_embeddings[n] for n in names], dtype=np.float64)
    except (TypeError, ValueError) as err:
        raise ValueError("stored embeddings must be numeric vectors of equal length") from err
    query = np.asarray(query_embedding, dtype=np.float64).ravel()
    if matrix.ndim != 2 or matrix.shape[1] != query.shape[0]:
        raise ValueError(
            f"embedding dimension mismatch: query has {query.shape[0]} values, "
            f"stored embeddings have shape {matrix.shape}"
        )

    # A zero vector has no direction; use a norm of 1 so its score is 0, not NaN.
    row_norms = np.linalg.norm(matrix, axis=1)
    row_norms[row_norms == 0.0] = 1.0
    query_norm = np.linalg.norm(query)
    if query_norm == 0.0:
        query_norm = 1.0

    scores = (matrix @ query) / (row_norms * query_norm)
    # Stable sort on the negated scores: descending order, ties in input order.
    order = np.argsort(-scores, kind="stable")[:top_k]
    return [(names[i], float(scores[i])) for i in order]


def read_preview(path, max_lines=PREVIEW_LINES):
    """Return up to *max_lines* whitespace-stripped lines from the start of *path*.

    Undecodable bytes are replaced rather than raising, so a file that is not
    valid UTF-8 cannot abort the interactive session.
    """
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        return [line.strip() for line in islice(f, max_lines)]


def copy_to_results(fname, src_path, results_dir=RESULTS_DIR):
    """Copy *src_path* to ``results_dir/fname`` and return the destination path.

    Missing parent directories of the destination are created first, so a
    file name that contains sub-directories copies cleanly.
    """
    dest_path = os.path.join(results_dir, fname)
    os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
    shutil.copyfile(src_path, dest_path)
    return dest_path


def main():
    """Prompt for a query, show the best matches with a preview, and offer to copy each."""
    # Ensure results directory exists
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Load stored embeddings first: it is cheap, so a bad file fails before
    # the slow model load.
    try:
        stored_embeddings = load_embeddings()
    except (OSError, ValueError) as err:
        raise SystemExit(f"[ERROR] Could not load {EMBEDDING_FILE}: {err}") from None
    if not stored_embeddings:
        raise SystemExit(f"[ERROR] {EMBEDDING_FILE} contains no embeddings.")

    # Load model
    model = load_model()

    # Prompt user
    query = input("\U0001F50D Enter your search query: ").strip()
    if not query:
        print("[WARN] Empty query, nothing to search for.")
        return

    # Embed query and rank the stored embeddings against it
    query_embedding = model.encode(query)
    try:
        matches = rank_matches(query_embedding, stored_embeddings)
    except ValueError as err:
        raise SystemExit(f"[ERROR] {err}") from None

    copied_files = []

    print("\n\U0001F4C2 Top matches:")
    for fname, score in matches:
        print(f"\n{fname} → score: {score:.4f}")
        txt_path = os.path.join(CONTENT_DIR, fname)
        # isfile, not exists: a directory with this name cannot be previewed.
        if not os.path.isfile(txt_path):
            print("[WARN] Source file not found for preview.")
            continue

        print("Preview:")
        for line in read_preview(txt_path):
            print(" " + line)

        # Ask user if they want to copy the file
        should_copy = input(f"📄 Copy '{fname}' to '{RESULTS_DIR}'? [y/N]: ").strip().lower()
        if should_copy in ("y", "yes"):
            try:
                dest_path = copy_to_results(fname, txt_path)
            except OSError as err:
                print(f"[ERROR] Could not copy '{fname}': {err}")
                continue
            copied_files.append(fname)
            print(f"[INFO] File copied to {dest_path}")

    # Final summary
    if copied_files:
        print("\n✅ Summary of copied files:")
        for name in copied_files:
            print(f" - {name}")
    else:
        print("\nℹ️ No files were copied.")


if __name__ == "__main__":
    main()