import marimo

# Presumably the marimo version that generated this file.
__generated_with = "0.18.0"
# Notebook application object; every `@app.cell` function below is registered
# on it. The layout file configures the slide layout.
app = marimo.App(
    width="medium",
    layout_file="layouts/information_retrieval.slides.json",
)


@app.cell
def _():
    # Expose micropip to the setup cell, which uses it to install packages
    # when the notebook is served over http.
    import micropip
    return (micropip,)


@app.cell
async def _(micropip):
    # Central import cell: loads the libraries shared by the rest of the
    # notebook and detects whether it is served over http (browser runtime).
    import marimo as mo
    import math
    import numpy as np
    import pandas as pd
    import altair as alt

    running_in_web = str(mo.notebook_location()).startswith("http")
    # Bind `pyodide` unconditionally: it is part of this cell's return value,
    # so leaving it unbound outside the browser makes the name undefined for
    # the cells that receive it. Those cells only touch it when
    # `running_in_web` is true, so `None` is a safe placeholder.
    pyodide = None
    if running_in_web:
        # In the browser the packages have to be installed before importing.
        await micropip.install("duckdb")
        await micropip.install("transformers-js-py")
        import pyodide
    import duckdb
    return alt, duckdb, math, mo, np, pd, pyodide, running_in_web


@app.cell
def _(mo, pyodide, running_in_web):
    # Browser capability check: the embedding cell relies on
    # `pyodide.ffi.run_sync`, so tell the user what to do when the browser
    # cannot run synchronous calls, and confirm support otherwise.
    if running_in_web:
        if not pyodide.ffi.can_run_sync():
            for _message in (
                "Your browser does not support running this notebook online.",
                "Use an updated version of Chrome.",
                "If you use firefox you need to set the javascript.options.wasm_js_promise_integration flag in about:config to true.",
                "Otherwise run this notebook offline instead.",
            ):
                mo.status.toast(_message)
            # `mo.stop` takes a required predicate; calling it with no
            # arguments raises a TypeError instead of halting the cell cleanly.
            mo.stop(True)
        else:
            mo.status.toast(
                "If you see this message your browser supports the JSPI feature needed for this notebook."
            )
    return


@app.cell
def _(mo, pd):
    # Load the document corpus from `public/data.csv` next to the notebook.
    # The location is converted to a string before being handed to pandas;
    # it can start with "http" when the notebook is served online.
    data_path = mo.notebook_location() / "public" / "data.csv"
    data_df = pd.read_csv(str(data_path))
    return (data_df,)


@app.cell
def _(data_df, mo):
    # (Re)create the `documents` table from the loaded DataFrame, casting `id`
    # to integer and keeping the `title` and `text` columns.
    dataset = mo.sql(
        f"""
        DROP TABLE IF EXISTS documents;
        CREATE TABLE IF NOT EXISTS documents as 
            SELECT cast(id as integer) as id, title, text from data_df;
        """
    )
    return


@app.cell
def _():
    # Provenance note only; nothing in this cell is executed.
    # data obtained via
    # df1 = mo.sql(
    #     f"""
    #     select * from 'hf://datasets/wikimedia/wikipedia/20231101.en/train-00010-of-00041.parquet' limit 100;
    #     """
    # )
    return


@app.cell(hide_code=True)
def _(mo):
    # Title slide.
    mo.md(r"""
    # (Neural-) Information Retrieval
    > Intelligent Agents Lecture Week 7<br> Lecturer: Noel Danz
    """)
    return


@app.cell(hide_code=True)
def _(mo):
    # Motivation slide: why retrieval is needed when the documents exceed the
    # context length of an LLM.
    mo.md(rf"""
    ## Motivating Example:

    Imagine we have a set of documents $D$ that (when tokenized) is **much** larger than the allowed context length of our LLM, think of a complete book, a series of webpages, or even collections of business documents that are created in a company.
    Ideally, we would like to pull relevant information from **all** of the documents.

    To do this we yank a strategy humans employ, automate it, and then enhance our LLM with it: search, filter, then answer.
    This approach leads to the idea of "Retrieval Augmented Generation (RAG)" we will see next week.

    This week's lecture will deal with "how do we search for **relevant** documents" - a rich research field called "information retrieval".
    """)
    return


@app.cell
def _():
    # ============
    # FUNDAMENTALS
    # ============
    # Section divider only: this cell runs no code.
    return


@app.cell
def _(mo):
    # Definition slide: what information retrieval is, with everyday examples.
    mo.md(r"""
    /// admonition | What is Information Retrieval?
    finding **relevant documents** from a large collection **in response to a query**
    ///

    Examples:
    - Search Engines (DuckDuckGo, Google)
    - _any_ search on a webpage
    - MacOS Spotlight / Windows Search
    """)
    return


@app.cell
def _(duckdb):
    # ======
    # Demo 1
    # ======
    # Registers a naive tokenizer as the DuckDB scalar function `tokenize1`.
    import re


    def tokenize1(text):
        """Lower-case *text* and return its runs of ASCII letters and digits."""
        return re.findall(r"[a-zA-Z0-9]+", text.lower())


    # Drop a registration left over from an earlier run of this cell so that
    # `create_function` does not fail on the existing name. Failing to remove
    # (e.g. on the first run, when nothing is registered yet) is expected and
    # ignored; `Exception` rather than a bare `except` keeps interrupts and
    # interpreter shutdown from being swallowed.
    try:
        duckdb.remove_function("tokenize1")
    except Exception:
        pass
    duckdb.create_function(
        "tokenize1", tokenize1, parameters=["VARCHAR"], return_type="VARCHAR[]"
    )
    return


@app.cell
def _(documents, mo):
    # Build the naive inverted index: tokenize every document with
    # `tokenize1`, unnest to one row per (id, token), then aggregate the
    # ids per token, ordered by id.
    _df = mo.sql(
        f"""
        drop table if exists inverse_doc_index;
        create table if not exists inverse_doc_index as
            WITH tokens AS (
            SELECT
                id,
                unnest(tokenize1(text)) AS token
            FROM documents
        )
        SELECT
            token,
            array_agg(id ORDER BY id) AS doc_ids
        FROM tokens
        GROUP BY token;
        """
    )
    return


@app.cell
def _(inverse_doc_index, mo):
    # Read the whole inverted index and render it as a table beneath the
    # explanation of what an inverted index is.
    _df = mo.sql(f"""select * from inverse_doc_index;""")
    mo.md(rf"""
    - traditional **IR** relies on matching _query terms_ to _documents_
    - this necessitates efficient data structures like inverted indices
    - Inverted Indices: maps each _term_ to the list of _documents containing it_

    Example of an inverted index:
    {mo.ui.table(_df)}
    """)
    return


@app.cell
def _(mo):
    # Text input for the single-token lookup demo; it is rendered by the cell
    # that embeds `idi_query` in its markdown.
    idi_query = mo.ui.text(label="Query:", value="citizens")
    # Trailing `None` expression: presumably keeps this cell from rendering
    # an output of its own.
    None
    return (idi_query,)


@app.cell
def _(idi_query, inverse_doc_index, mo):
    # Look up the raw query text as one token in the naive inverted index.
    # The text comes straight from a UI input and is spliced into a SQL string
    # literal, so single quotes are doubled first; otherwise an input such as
    # "don't" ends the literal early and the statement fails to parse.
    # The query is deliberately not normalized: the note below relies on
    # differently written words not matching.
    _token_literal = idi_query.value.replace("'", "''")
    idi_query_res = mo.sql(
        f"""select * from inverse_doc_index where token='{_token_literal}';"""
    )
    mo.md(f"""
        Try searching the inverted index below.
        ///admonition | Note
        Notice how the example does not allow to search for more than one word, and how different ways of writing the word "break" the search.
        ///
        {idi_query}
        {mo.ui.table(idi_query_res)}
    """)
    return


@app.cell
def _():
    # ========================
    # INDEXING & PREPROCESSING
    # ========================
    # Section divider only: this cell runs no code.
    return


@app.cell
def _(mo):
    # Slide "Building a Better Search Index": heading, a mermaid diagram of
    # the indexing pipeline, and closing remarks. The three parts are built
    # first and then appended to the cell output in order.
    _heading = mo.md(r"""
    ## Building a Better Search Index
    """)
    _pipeline = mo.mermaid("""
    flowchart LR
        cd(collecting documents) --> tt
        subgraph tt [tokenizing text]
            direction LR
            normalization-->
            stemming/lemmatization-->
            tokenization
        end
        tt --> ci
        ci(create index)
    """)
    _remarks = mo.md(r"""
    - preprocessing improves matching _considerably_
    - inverted indices dramatically speed up search vs. full-text search
      - if you have heard about it: fast high quality text search is one of the killer features in HANA/S4
    """)
    mo.output.append(_heading)
    mo.output.append(_pipeline)
    mo.output.append(_remarks)
    return


@app.cell
async def _(duckdb, mo):
    # ==============================================================
    # DEMO: Making a Normalized & Stemmed & Tokenized Inverted Index
    # ==============================================================
    # Makes the punkt tokenizer data available to nltk, defines `tokenize2`
    # and registers it as a DuckDB scalar function.
    import nltk
    from nltk.tokenize import word_tokenize
    from nltk.stem import PorterStemmer
    import string

    # Make sure tokenizers are downloaded
    punkt_tab_loc = str(mo.notebook_location() / "public" / "punkt.zip")
    if punkt_tab_loc.startswith("http"):
        # Served over http: download the archive through the browser's fetch
        # API and unpack it under /nltk_data/tokenizers.
        from js import fetch
        from pathlib import Path
        import zipfile

        response = await fetch(punkt_tab_loc)
        js_buffer = await response.arrayBuffer()
        py_buffer = js_buffer.to_py()  # this is a memoryview
        stream = py_buffer.tobytes()  # now we have a bytes object

        d = Path("/nltk_data/tokenizers")
        d.mkdir(parents=True, exist_ok=True)
        Path("/nltk_data/tokenizers/punkt.zip").write_bytes(stream)

        # Extract the downloaded punkt.zip; the context manager closes the
        # archive handle once extraction is done.
        with zipfile.ZipFile("/nltk_data/tokenizers/punkt.zip") as _archive:
            _archive.extractall(path="/nltk_data/tokenizers/")
    else:
        # Local run: add the bundled archive to nltk's data search path.
        nltk.data.path.append(punkt_tab_loc)


    def tokenize2(text):
        """Return the stemmed word tokens of *text*.

        The text is lower-cased and stripped of ASCII punctuation, split with
        nltk's `word_tokenize`, and each token is reduced with the Porter
        stemmer.
        """
        # Normalize: lowercase
        text = text.lower()
        # Remove punctuation
        text = text.translate(str.maketrans("", "", string.punctuation))
        # Tokenize: Splitting into Words
        tokens = word_tokenize(text)
        # Stemming: Normalizing the Words into their "raw" form.
        stemmer = PorterStemmer()
        stemmed_tokens = [stemmer.stem(token) for token in tokens]
        return stemmed_tokens


    # Drop a registration left over from an earlier run of this cell so that
    # `create_function` does not fail on the existing name. A failed removal
    # (nothing registered yet) is expected and ignored; `Exception` rather
    # than a bare `except` keeps interrupts from being swallowed.
    try:
        duckdb.remove_function("tokenize2")
    except Exception:
        pass
    duckdb.create_function(
        "tokenize2", tokenize2, parameters=["VARCHAR"], return_type="VARCHAR[]"
    )
    None  # NO OUTPUT
    return (tokenize2,)


@app.cell
def _(documents, mo):
    # Build the improved inverted index: same construction as the naive one,
    # but tokens come from `tokenize2`.
    _df = mo.sql(
        f"""
        drop table if exists inverse_doc_index2;
        create table if not exists inverse_doc_index2 as
            WITH tokens AS (
            SELECT
                id,
                unnest(tokenize2(text)) AS token
            FROM documents
        )
        SELECT
            token,
            array_agg(id ORDER BY id) AS doc_ids
        FROM tokens
        GROUP BY token;
        """
    )
    return


@app.cell
def _(mo):
    # Text input for the improved (multi-word, stemmed) index demo; it is
    # rendered by the cell that embeds `idi2_query` in its markdown.
    idi2_query = mo.ui.text(value="Hello World", label="Query: ")
    None  # NO OUTPUT
    return (idi2_query,)


@app.cell
def _(idi2_query, mo, tokenize2):
    # Tokenize the query exactly like the indexed documents and fetch the
    # index rows of every resulting token.
    # The SQL list is built explicitly instead of interpolating the Python
    # list repr, which is not a SQL literal syntax: each token becomes a
    # quoted string with embedded quotes doubled. An empty query yields
    # `IN (NULL)`, which is valid SQL and matches no rows.
    _tokens = tokenize2(idi2_query.value)
    _in_list = (
        ", ".join("'" + _tok.replace("'", "''") + "'" for _tok in _tokens)
        or "NULL"
    )
    idi2_query_res = mo.sql(
        f"""select * from inverse_doc_index2 where token in ({_in_list});""",
    )
    mo.md(f"""
        Play around with the improved search index below:
        ///admonition | Note
        Notice how now you can not only search for multiple words, the search is more robust against how words are written.<br>
        Compare `run` and `running` as an example.
        ///
        {idi2_query}
        {mo.ui.table(idi2_query_res)}
    """)
    return


@app.cell
def _():
    # =======================
    # TERM WEIGHING & RANKING
    # =======================
    # Section divider only: this cell runs no code.
    return


@app.cell
def _(mo):
    # Transition slide: binary term occurrence is a poor notion of relevance.
    mo.md(r"""
    # Coming up with better indices.

    Currently our indices are always assuming that occurrence of a word is _the same as_ relevancy of the document.
    This makes relevance a binary measure, which does not align with how we would think about texts being relevant.

    The following approaches introduce ideas that try to align the notion of _"relevance"_ to something we would consider both more fitting and better suited for retrieval.
    """)
    return


@app.cell
def _(mo):
    # Slide: term frequency, normalized by the most frequent term in the
    # document.
    mo.md(r"""
    ## TF - Term Frequency

    /// note | Idea
    If a term occurs _frequently_ in a document it should be more relevant.<br>
    Some sources also regularize by the length of the document.
    ///


    $$
    tf(t,D) = \frac{f(t,D)}{max_{t'\in D}f(t',D)}
    $$
    """)
    return


@app.cell
def _(mo):
    # Slide: TF-IDF, i.e. term frequency weighted by the log of the inverse
    # share of documents that contain the term.
    mo.md(r"""
    ## TF-IDF
    > **T**erm **F**requency - **I**nverse **D**ocument **F**requency

    /// note | Idea
    A term should be _more_ relevant if it is more specialized (i.e.: occurs in fewer documents overall).
    ///

    $$tf\cdot idf$$

    $$
    idf(t) = ln(\frac{N}{\sum_{D:t\in D}1})
    $$
    """)
    return


@app.cell
def _(mo):
    # Sliders feeding the BM25 derivative plot: `t` (1..200, default 20) and
    # `θ` (10..200, default 25).
    t_slide = mo.ui.slider(start=1, stop=200, label="t", show_value=True, value=20)
    theta_slide = mo.ui.slider(
        start=10.0, stop=200.0, label="θ", show_value=True, value=25.0
    )
    return t_slide, theta_slide


@app.cell
def _(alt, mo, np, pd, t_slide, theta_slide):
    # Plots the partial derivatives of one BM25 summand with respect to θ and
    # to t, each evaluated on the grid 1..399 while the other variable is
    # taken from its slider.
    k, b = 1.5, 0.75
    idf = 0.05
    t = t_slide.value
    theta = theta_slide.value


    def by_theta(theta):
        # more punishment to documents that are longer
        return (
            idf
            * (-t * (k + 1) * k * b)
            / ((t + k * (1 - b + b * theta)) ** 2)
        )


    def by_t(t):
        # see that increasing t has diminishing returns for the derivative
        return (
            idf
            * (k * (k + 1) * (1 - b + b * theta))
            / ((t + k * (1 - b + b * theta)) ** 2)
        )


    _grid = np.arange(1.0, 400.0, 1.0)
    source_theta = pd.DataFrame(
        {"θ": _grid, "F": [by_theta(_x) for _x in _grid]}
    )
    source_t = pd.DataFrame(
        {"t": _grid, "F": [by_t(_x) for _x in _grid]}
    )

    _theta_chart = (
        alt.Chart(source_theta)
        .mark_line()
        .encode(x="θ", y="F")
        .properties(title="∂BM25/∂θ")
    )
    _t_chart = (
        alt.Chart(source_t)
        .mark_line()
        .encode(x="t", y="F")
        .properties(title="∂BM25/∂t")
    )
    bm25_plot = mo.ui.altair_chart(alt.hconcat(_theta_chart, _t_chart))
    None  # NO OUTPUT
    return (bm25_plot,)


@app.cell
def _(mo):
    # Slide: BM25 — the two TF-IDF problems it addresses and its scoring
    # formula, including the definition of the length ratio θ.
    mo.md(r"""
    ## BM25
    > "Best Matching 25" ([explainer video](https://www.youtube.com/watch?v=ruBm9WywevM))

    A very well established approach of improving on TF-IDF – if you come up with a novel way of doing retrieval this is what it has to beat **at minimum**.

    /// hint | TF-IDF Problem 1
    Assume Doc $A$ has $tf=1$, $|A|=10$ and Doc $B$ has $tf=10$, $|B|=1000$, with same $idf$.<br>
    Then $tfidf(A)<<tfidf(B)$.<br>
    Even though we would consider $A$ to be "more concise" and should get a higher score.
    ///

    /// hint | TF-IDF Problem 2
    Keyword Stuffing. When adding a bunch of the keyword to the end of the text, we improve our score.<br>
    We would rather have a diminishing return on the improvement per each occurrence.
    ///

    /// note | Idea
    1. Improve TF such that it has diminishing returns.
    2. Penalize relevancy by length of document.
    ///

    $$
    s(D,Q) = \sum_{i=1}^n {
    IDF(q_i) \cdot
    \frac{f(q_i,D)\cdot (k_1+1)}{f(q_i,D)+k_1(1-b+b\theta)}
    }
    $$

    with $IDF(q_i)$ being the inverse document frequency, $f(q_i,D)$ being the number of times $q_i$ appears in $D$.
    Also $\theta = \frac{|D|}{\overline{L_D}}$ with $\overline{L_D}$ being the average document length.
    """)
    return


@app.cell
def _(bm25_plot, mo, t_slide, theta_slide):
    # Slide: interactive BM25 derivative plots with their two sliders and a
    # reading guide. θ is the document length relative to the average
    # document length, so the guide describes growth of θ in those terms.
    mo.md(rf"""
    Let's see if we were able to fix our issues:

    {t_slide}{theta_slide}
    {bm25_plot}

    Notice how in the $\partial t$ plot the change in score (added to the relevancy score) is sinking over the number of occurrences of $t$. <br>
    Notice how in the $\partial \theta$ plot the change in score is approaching zero when the document length relative to the average document length is increasing, i.e., that the relevancy score gets re-weighed by how much of the overall document a term makes up.
    """)
    return


@app.cell
def _():
    # ==============================
    # RETRIEVAL EVALUATION & METRICS
    # ==============================
    # Section divider only: this cell runs no code.
    return


@app.cell
def _(mo):
    # Slide: retrieval metrics. Precision and recall share the numerator
    # |y ∧ ŷ| (retrieved items that are actually relevant) and differ only in
    # the denominator: retrieved items for precision, relevant items for recall.
    mo.md(r"""
    ## How to measure Retrieval Success:

    Common metrics: precision $Pr$, recall $Rc$, mean-reciprocal-rank $mRR$.<br>
    Also used: mean-average-precision $mAP$, normalized discounted cumulative gain $nDCG$.

    |Metric | Def. | Explanation |
    |:------|:----:|:------------|
    | Precision | $\frac{\vert y\land \hat{y}\vert}{\vert\hat{y}\vert}$ | How many of my positives, were actual positives.|
    |Recall| $\frac{\vert y\land \hat{y}\vert}{\vert y\vert}$ |How many of the actual positives did I catch?|
    |Mean Reciprocal Rank| $\frac{1}{N}\sum_{i=1}^N{\frac{1}{rank_i}}$ | The average, over all queries, of the reciprocal rank of the first _correct_ retrieval.|
    """)
    return


@app.cell
def _():
    # ============================
    # NEURAL INFORMATION RETRIEVAL
    # ============================
    # Section divider only: this cell runs no code.
    return


@app.cell
def _(mo):
    # Slide: the vocabulary gap as motivation for learned relevance.
    mo.md(r"""
    ## Neural Information Retrieval

    /// hint | Problem: The vocabulary gap.
    Words like dog, puppy, and dalmatian are related.<br>
    Our current approaches consider these terms to be _semantically independent_.
    ///

    /// note | Idea
    Instead of using term weights based on occurrence frequency, we could _learn_ what is relevant, based on _semantic closeness_.
    ///
    """)
    return


@app.cell
def _(mo):
    # Slide: embeddings as vector representations of a text's topic, and
    # where such representations come from.
    mo.md(r"""
    ## Semantic Closeness

    /// note | Idea
    We learn a _vector_ that best represents what the piece of text is talking about.
    Sometimes referred to as the text's "topic".<br>
    This procedure is called "embedding".
    ///

    Luckily modern language modelling approaches have this representation internally "by default":
    - TF-IDF with SVD: By calculating a rank-reduced representation of a TF-IDF matrix, we can obtain useful embeddings
      ([sklearn-example](https://scikit-learn.org/stable/auto_examples/text/plot_document_clustering.html#sphx-glr-auto-examples-text-plot-document-clustering-py)).
    - LSTMs, GRUs: The internal hidden state is the "same" as the topic of the text.
    - Transformers: The representation right before the `lm_head` is often used as an embedding.

    /// note
    We assume that to be able to continue a piece of text, the model needs to be able to accurately represent what has been talked about before – which also holds in practice.
    ///
    """)
    return


@app.cell
def _(mo):
    # Slide: distance and similarity measures between embedding vectors.
    mo.md(r"""
    ## Semantic Closeness: Measuring

    /// note |
    We need to be able to measure the similarity of two vectors representing topics.
    ///

    We will assume that similarity is equivalent to "closeness", meaning some **measure of distance** within the embedding space.

    Commonly used measures:
    - $L_2$-distance $L_2(a,b) = \sqrt{\sum_N(a_i-b_i)^2}$
    - dot-product $a\cdot b = \sum{a_ib_i}$, how directionally aligned and similar in magnitude two vectors are
    - cosine similarity $S_c(a,b) = \frac{a\cdot b}{||a||||b||} = cos(\theta)$, how directionally similar two vectors are independent of magnitude
    """)
    return


@app.cell
async def _(duckdb, mo, np, pyodide, running_in_web):
    # Defines `embed_text(text) -> list[float]` on top of all-MiniLM-L6-v2 and
    # registers it as a DuckDB scalar function. The browser uses the
    # transformers.js pipeline; a native run uses Hugging Face transformers
    # with torch. Both branches mean-pool and normalize so that their vectors
    # are comparable.
    if running_in_web:
        from transformers_js_py import import_transformers_js

        transformers = await import_transformers_js()
        pipeline = transformers.pipeline

        pipe = await pipeline("feature-extraction", "Xenova/all-MiniLM-L6-v2")

        # `embed_text` below needs `run_sync`; warn if the browser lacks it.
        if not pyodide.ffi.can_run_sync():
            mo.status.toast(
                "Your browser does not support running this notebook online."
            )
            mo.status.toast("Use an updated version of Chrome.")
            mo.status.toast(
                "If you use firefox you need to set the javascript.options.wasm_js_promise_integration flag in about:config to true."
            )
            mo.status.toast("Otherwise run this notebook offline instead.")

        def embed_text(text: str):
            """Return the pipeline's mean-pooled, normalized embedding of *text*."""
            # DuckDB calls the function synchronously, so the awaitable
            # returned by the pipeline is resolved with `run_sync`.
            out = pyodide.ffi.run_sync(
                pipe(text, {"pooling": "mean", "normalize": "true"})
            )
            return np.squeeze(out.to_numpy()).tolist()
    else:
        from transformers import AutoTokenizer, AutoModel
        import torch

        tokenizer = AutoTokenizer.from_pretrained(
            "sentence-transformers/all-MiniLM-L6-v2"
        )
        model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

        def embed_text(text: str):
            """Return the mean-pooled, L2-normalized embedding of *text*."""
            inputs = tokenizer(text, return_tensors="pt", truncation=True)
            with torch.no_grad():
                outputs = model(**inputs)
            # Mean-pool the token embeddings, weighting by the attention mask
            # so padding does not contribute, then L2-normalize. This mirrors
            # the pooling/normalize options of the browser pipeline; the
            # model's `pooler_output` is a different vector and would make the
            # two runtimes rank documents differently.
            token_embeddings = outputs.last_hidden_state
            mask = (
                inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
            )
            pooled = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(
                min=1e-9
            )
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            # batch_size is 1, so take the single row
            return pooled[0].numpy().tolist()


    # Drop a registration left over from an earlier run of this cell so that
    # `create_function` does not fail on the existing name. A failed removal
    # (nothing registered yet) is expected and ignored; `Exception` rather
    # than a bare `except` keeps interrupts from being swallowed.
    try:
        duckdb.remove_function("embed_text")
    except Exception:
        pass
    duckdb.create_function(
        "embed_text",
        embed_text,
        parameters=["VARCHAR"],
        return_type="DOUBLE[]",
    )
    None  # NO OUTPUT
    return (embed_text,)


@app.cell
def _(documents, mo):
    # (Re)create `doc_embeddings` with one embedding per document, computed
    # by the registered `embed_text` function.
    _df = mo.sql(
        f"""
        drop table if exists doc_embeddings;
        create table if not exists doc_embeddings as
            SELECT
                id,
                embed_text(text) AS emb
            FROM documents;
        """
    )
    return


@app.cell
def _(duckdb, math):
    # Distance functions for the embedding search, registered as DuckDB
    # scalar functions. All three are oriented so that a smaller value means
    # "closer", which lets the search cell always sort ascending.
    def dot(a, b):
        """Return the negated dot product of *a* and *b*."""
        # neg. of dot-product to be able to use the same ascending sort
        return -sum(float(x) * float(y) for x, y in zip(a, b))


    def cossim(a, b):
        """Return the cosine distance (1 - cosine similarity) of *a* and *b*.

        Returns None when either vector has zero norm.
        """
        a = [float(x) for x in a]
        b = [float(x) for x in b]

        num = math.fsum(x * y for x, y in zip(a, b))
        den_a = math.sqrt(math.fsum(x * x for x in a))
        den_b = math.sqrt(math.fsum(x * x for x in b))

        return 1 - (num / (den_a * den_b)) if den_a and den_b else None


    def l2(a, b):
        """Return the squared Euclidean distance of *a* and *b*."""
        # we can skip the sqrt as we just _order_ distances
        return sum((float(x) - float(y)) ** 2 for x, y in zip(a, b))


    # Register each function under its own name; in particular "l2" must be
    # bound to `l2`, not to `cossim`. An existing registration from an earlier
    # run is dropped first; a failed removal (nothing registered yet) is
    # expected and ignored, without swallowing interrupts.
    for _name, _func in (("dot", dot), ("cossim", cossim), ("l2", l2)):
        try:
            duckdb.remove_function(_name)
        except Exception:
            pass
        duckdb.create_function(_name, _func, return_type="DOUBLE")
    return


@app.cell
def _(mo):
    # Widgets for the embedding search: the query text and the distance
    # function. The radio values ("cossim", "dot", "l2") are the names of the
    # registered DuckDB functions and are spliced into the search SQL.
    emb_query = mo.ui.text(value="Helen Bernstein High School", label="Query: ")
    dist_func = mo.ui.radio(
        options={"cossim": "cossim", "-1·dot-product": "dot", "euclidean": "l2"},
        inline=True,
        value="cossim",
    )
    return dist_func, emb_query


@app.cell
def _(dist_func, emb_query, embed_text, mo):
    # Embed the query, compute its distance to every stored document
    # embedding with the selected function, and join back to `documents`,
    # sorted ascending by distance (smallest first).
    embedded_query = embed_text(emb_query.value)
    # The embedding is interpolated as a list literal and cast to DOUBLE[].
    emb_query_res = mo.sql(
        f"""
        with closest 
        as (
        SELECT id, {dist_func.value}(emb, {embedded_query}::DOUBLE[]) AS dist
        FROM doc_embeddings
        ORDER BY dist ASC
        )
        select closest.dist, documents.id, title, text from documents inner join closest on documents.id=closest.id
        order by closest.dist asc
        """
    )
    mo.md(f"""
    Play around a bit with the embedding based search below:

    {emb_query}
    {dist_func}
    {mo.ui.table(emb_query_res)}

    ///note 
    You can search with arbitrary strings, and the search results will be noticeably closer to the meaning of the query than when we used an inverted index.
    ///
    """)
    return


# Run the notebook app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
