backprompt

backprompt provides a data structure that lets you dynamically construct prompts while avoiding repeated LLM computation.

Motivation

In many large-scale LLM tasks, the same prompt is used many times: once for each instance of the task. In cases like these, the computation performed by future LLM calls can be reduced by caching and re-using the LLM's representation of the prompt.

backprompt takes this well-known idea a step further by additionally caching LLM representations of intermediate text in the prompt. Intermediate caching is useful when you need to dynamically adjust part of the prompt without re-computing the LLM's representation of the parts that haven't changed. backprompt abstracts the complex process of prompt construction and caching as plain-old string concatenation.
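
For reference, here's a minimal sketch of the prompt-caching idea using Hugging Face transformers directly. It is not backprompt's implementation; it only illustrates the standard past_key_values mechanism that makes re-use possible: the shared prompt is processed once, and each later call processes only its new suffix.

import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Process the shared prompt once and keep its per-layer keys and values
prompt_ids = tokenizer('Hello there.', return_tensors='pt').input_ids
with torch.no_grad():
    prompt_out = model(prompt_ids, use_cache=True)

for suffix in (' Senator', ' General'):
    suffix_ids = tokenizer(suffix, return_tensors='pt').input_ids
    with torch.no_grad():
        # Only the suffix tokens are run through the model; the prompt's cached
        # keys and values are re-used. The cache is copied because newer
        # transformers versions may update the cache object in place.
        out = model(
            suffix_ids,
            past_key_values=copy.deepcopy(prompt_out.past_key_values),
        )
    next_token_logits = out.logits[:, -1, :]  # logits after prompt + suffix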

Usage

See the notebook demos/minimal_example.ipynb for a more realistic use case. Here's a toy demo:

from transformers import AutoModelForCausalLM, AutoTokenizer
from backprompt import Text

# Load a GPT model and its tokenizer
model_name = 'gpt2'
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
mt = (model, tokenizer)

# Wrap strings in Text and construct them via concatenation
context = Text('Hello there.', mt)
choices = [Text(' Senator', mt), Text(' General', mt)]
endings = [Text(' Amidala', mt), Text(' Kenobi...', mt)]

texts = [context + choice + ending for choice in choices for ending in endings]
print(texts[-1].string)
# Hello there. General Kenobi...

# Get next-token logits by calling every text obj
# The punchline is that you don't have to worry about repeated computation
for text in texts:
    text()

# Next-token logits for 'Hello there. General Kenobi...'
texts[-1].model_repr[1].logits[:, -1, :]
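
Across the four texts above, the shared context 'Hello there.' and each context + choice prefix should only need to be processed once; each call computes representations only for the tokens that are new to that text.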

Installation

python -m pip install git+https://github.com/kddubey/backprompt.git

How it works

If you basically know how backprop works (watch this YouTube video), and you basically know how a decoder-only autoregressive language model works (watch this YouTube video), then you know how backprompt works :-)

Analogies (made concrete in the sketch below):

  • backprop's "intermediate" gradients of a function ↔ backprompt's attention block keys and values
  • backprop's gradient of a function ↔ backprompt's token logits
  • backprop's chain rule ↔ backprompt's tensor concatenation
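
Here's a sketch of the first and third analogies in code, again using plain transformers rather than backprompt's internals: the cached "intermediate" representation of a prefix is its per-layer attention keys and values, and extending the prefix amounts to concatenating the suffix's keys and values onto the cached ones along the sequence dimension.

import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

prefix_ids = tokenizer('Hello there.', return_tensors='pt').input_ids
suffix_ids = tokenizer(' General', return_tensors='pt').input_ids

with torch.no_grad():
    # "Intermediate" representation of the prefix: its per-layer keys and values
    prefix_out = model(prefix_ids, use_cache=True)
    # Extend the prefix: only the suffix tokens are processed, and their keys
    # and values are appended to the cached ones along the sequence dimension
    suffix_out = model(
        suffix_ids,
        past_key_values=copy.deepcopy(prefix_out.past_key_values),
        use_cache=True,
    )

prefix_keys = prefix_out.past_key_values[0][0]    # layer-0 keys: (batch, heads, seq, head_dim)
combined_keys = suffix_out.past_key_values[0][0]
assert combined_keys.shape[-2] == prefix_keys.shape[-2] + suffix_ids.shape[-1]
# "Final" representation: logits for the token after 'Hello there. General'
next_token_logits = suffix_out.logits[:, -1, :]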

TODO: graph visualization

Testing

TODO: expand test cases

pytest

Todos

Research
  • What's the computational complexity of using past keys and values w.r.t. the number of tokens?
  • Do few-shot prompts exhibit interesting independencies? If so, one could construct prompts using different examples on the fly.

Code
  • Expand tests
    • More autoregressive LMs
    • More string breakdowns
  • Graph visualization
  • Allow for frozen representations / custom independencies in the graph
  • Batching
  • Eager mode
  • ModelRepr dataclass for convenience
    • Add and update a token_logprobs attribute to the LM output obj
    • By default, only keep last (non-pad) token's logits in the LM output obj
  • Documentation?
