Files
at1-workstation-scripts/semantic_search.py
2025-04-13 16:05:19 +02:00

80 lines
2.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# Script Version: 0.4
# Description: Semantic search over local embeddings.json with content preview and optional file copy
import json
import torch
import os
import shutil
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# ---------------------------------------------------------------------------
# Configuration — all paths are relative to the current working directory.
# ---------------------------------------------------------------------------
MODEL_NAME = "all-mpnet-base-v2"    # SentenceTransformer used to embed the query
EMBEDDING_FILE = "embeddings.json"  # JSON object: filename -> stored embedding vector
CONTENT_DIR = "content"             # folder holding the files named in EMBEDDING_FILE
RESULTS_DIR = "results"             # destination for files the user chooses to copy
PREVIEW_LINES = 5                   # leading lines shown for each matching file

# Create the output folder up front (no-op when it already exists) so later
# copies never hit a missing destination directory.
os.makedirs(RESULTS_DIR, exist_ok=True)
# Load the sentence-embedding model; when CUDA is available, move it onto the
# GPU and say so.
model = SentenceTransformer(MODEL_NAME)
gpu_available = torch.cuda.is_available()
if gpu_available:
    model = model.to("cuda")
    print("[INFO] Running on GPU")
# Load the precomputed index: a JSON object mapping each filename to its
# embedding vector (consumed below via .items()).
# The encoding is pinned to UTF-8 because open()'s default is locale-dependent
# (e.g. cp1252 on Windows) and would mis-decode non-ASCII filenames.
try:
    with open(EMBEDDING_FILE, "r", encoding="utf-8") as f:
        stored_embeddings = json.load(f)
except FileNotFoundError:
    # Exit with a readable message instead of a bare traceback.
    raise SystemExit(f"[ERROR] Embedding file not found: {EMBEDDING_FILE}") from None
# Ask the user for the search text (surrounding whitespace removed).
query = input("\U0001F50D Enter your search query: ").strip()
# An empty query carries no meaning. Stop instead of ranking every file
# against an empty string and offering arbitrary copies.
if not query:
    raise SystemExit("[WARN] Empty query, nothing to search for.")
# Embed the query with the loaded model. This presumes embeddings.json was
# built with the same MODEL_NAME so the vectors are comparable (TODO confirm).
query_embedding = model.encode(query)
# Compute cosine similarities between the query and every stored embedding in
# ONE vectorized call. Calling cosine_similarity once per file re-validates
# and re-normalises inputs for every pair, which is needlessly slow on large
# collections.
filenames = list(stored_embeddings)  # same order as .items()
if filenames:
    matrix = np.asarray([stored_embeddings[name] for name in filenames])
    scores = cosine_similarity([query_embedding], matrix)[0]
    results = list(zip(filenames, scores))
else:
    # cosine_similarity rejects an empty matrix; an empty index just has no matches.
    results = []
# Rank best-first. The sort is stable, so tied scores keep file order.
results.sort(key=lambda x: x[1], reverse=True)
# Present the best matches. For each one, print a short preview and ask
# whether to copy the file into RESULTS_DIR.
TOP_K = 3  # number of highest-scoring files to present
copied_files = []
print("\n\U0001F4C2 Top matches:")
for fname, score in results[:TOP_K]:
    print(f"\n{fname} → score: {score:.4f}")
    txt_path = os.path.join(CONTENT_DIR, fname)
    # isfile (not exists): a directory of the same name cannot be previewed or copied.
    if not os.path.isfile(txt_path):
        print("[WARN] Source file not found for preview.")
        continue
    print("Preview:")
    # errors="replace": a file that is not valid UTF-8 still previews (with
    # U+FFFD placeholders) instead of aborting the whole interactive session.
    with open(txt_path, "r", encoding="utf-8", errors="replace") as f:
        for i, line in enumerate(f):
            # Check the limit before printing so PREVIEW_LINES = 0 prints nothing.
            if i >= PREVIEW_LINES:
                break
            print(" " + line.strip())
    # Ask user if they want to copy the file ("y" or "yes", case-insensitive).
    should_copy = input(f"📄 Copy '{fname}' to '{RESULTS_DIR}'? [y/N]: ").strip().lower()
    if should_copy in ("y", "yes"):
        dest_path = os.path.join(RESULTS_DIR, fname)
        # If fname contains sub-directories, create them so copyfile does not fail.
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        shutil.copyfile(txt_path, dest_path)
        copied_files.append(fname)
        print(f"[INFO] File copied to {dest_path}")
# Final report: list every file copied during this session, or say none were.
if not copied_files:
    print("\n No files were copied.")
else:
    print("\n✅ Summary of copied files:")
    print("\n".join(f" - {name}" for name in copied_files))