Building Smart Medical Diagnostics: LLM Auto-Prompt & Chaining with DSPy
Introduction
So I was reading papers on prompting when I came across "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines" (https://arxiv.org/pdf/2310.03714). The framework piqued my interest because, until a few months ago, prompt engineering was all the rage, and the job market was buzzing with roles for prompt engineers. That trend has since shifted. Prompt engineering was neither an art nor a science but closer to the Clever Hans effect: much of the apparent skill came from humans supplying the context the system needed to respond well.
Books and blogs sprouted, claiming to provide the “Top 50 prompts to get the best out of GPT,” among other grand promises. However, large-scale experiments have shown that there is no single prompt or strategy that works for all problems. Some prompts may seem better in isolation but turn out to be hit-and-miss when analyzed comprehensively.
So today, we are going to explore DSPy, a framework developed at Stanford for compiling declarative language model calls into self-improving pipelines. Instead of hand-writing prompts, you declare what each LM call should do and compose those calls into modules, much like layers in PyTorch; a compiler then optimizes the prompts for you (for example, by bootstrapping few-shot demonstrations).
Why a Medical Diagnosis Example?
Medical diagnosis is a complex task that requires detailed reasoning and precise decision-making, which makes it a good domain for demonstrating DSPy’s Chain-of-Thought (CoT) module. The CoT approach breaks the diagnostic process into intermediate steps that lead to a final diagnosis. This mirrors the stepwise reasoning of medical professionals and makes the model's intermediate reasoning visible, so each step can be inspected rather than taken on trust.
In this guide, we will walk through the steps to create a medical diagnosis application using DSPy. We will cover:
- Setting up the environment
- Defining the task with DSPy signatures
- Creating and optimizing modules for intermediate diagnostic steps and final diagnoses
- Comparing results before and after optimization
By the end of this tutorial, you will have a comprehensive understanding of how to use DSPy to enhance medical diagnostics through automated reasoning chains.
Understanding Prompting Libraries
To understand where DSPy fits into the landscape of prompting libraries, it’s helpful to look at the different types of libraries available for working with large language models (LLMs):
Prompt Wrappers:
These libraries provide a minimal level of abstraction for creating and managing prompts. They typically offer simple tools for string insertion and extraction, allowing users to quickly generate prompts without much overhead.
Example libraries: MiniChain
Application Development Libraries:
These high-level libraries are designed for building applications with LLMs. They come with pre-built modules and tools that help developers quickly configure and integrate LLM capabilities into their applications. These libraries abstract away much of the complexity involved in working with LLM APIs.
Example libraries: LangChain, LlamaIndex
Generation Control Libraries:
These libraries focus on controlling the output of LLMs. They provide tools for implementing control flows, enforcing specific output schemas (such as JSON), and constraining the model’s outputs to match certain patterns or regular expressions.
Example libraries: Guidance, LMQL, RELM, Outlines
Prompt Generation & Automation:
Similar to traditional machine learning, these libraries define inputs and targets along with high-level operators or building blocks. They optimize or generate the specifics for each prompt and stage, automating much of the prompt development process.
Example libraries: DSPy
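To make the first category concrete, here is a minimal prompt-wrapper sketch in plain Python, with no library involved. The template wording, the "Diagnosis:" marker, and the function names build_prompt and extract_diagnosis are hypothetical; the point is that every word of the prompt and the parsing rule are maintained by hand, which is the work the last category tries to automate.

```python
import re
from typing import Optional

# Hypothetical hand-written template: the developer owns every word of the prompt.
TEMPLATE = (
    "Patient symptoms: {symptoms}\n"
    "Medical history: {history}\n"
    "Answer on one line that starts with 'Diagnosis:'."
)


def build_prompt(symptoms: str, history: str) -> str:
    """String insertion: fill the template's placeholders."""
    return TEMPLATE.format(symptoms=symptoms, history=history)


def extract_diagnosis(completion: str) -> Optional[str]:
    """String extraction: pull the answer back out of a raw completion."""
    match = re.search(r"^Diagnosis:\s*(.+)$", completion, flags=re.MULTILINE)
    return match.group(1).strip() if match else None


print(build_prompt("fever, cough", "asthma"))
print(extract_diagnosis("The fever suggests infection.\nDiagnosis: influenza"))  # influenza
```

If the model drifts from the expected format, extract_diagnosis silently returns None, and fixing that means editing strings by hand.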
The Role of DSPy
Prompt development and LLM chaining often require extensive trial and error: you have to write effective prompts and also decompose the problem into discrete tasks that an LLM can handle reliably. DSPy simplifies this by letting you declare those tasks as modules and then optimize and evaluate them automatically against a metric you define.
With DSPy, you can:
- Define complex tasks: use DSPy’s signatures and modules to break complex tasks into manageable sub-tasks.
- Automate prompt generation: DSPy generates and optimizes the prompts for each stage of a task, reducing the need for manual tweaking.
- Optimize performance: DSPy’s teleprompters (optimizers) compile your program against training examples and a metric, for example by bootstrapping few-shot demonstrations, and you can re-run compilation whenever the data, metric, or model changes.
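The bootstrapping idea behind DSPy's few-shot teleprompters can be sketched without any LLM at all: run the current program on training inputs, keep the runs the metric accepts, and reuse them as demonstrations. This is a deliberately simplified, hypothetical illustration (plain dictionaries, a stub program, and a bootstrap_demos helper invented for this example); DSPy's actual BootstrapFewShot records the inputs and outputs of every predictor in the trace and has more options.

```python
from typing import Callable, Dict, List

Example = Dict[str, str]


def bootstrap_demos(
    program: Callable[[Example], Example],
    trainset: List[Example],
    metric: Callable[[Example, Example], bool],
    max_demos: int = 4,
) -> List[Example]:
    """Collect the program's own successful runs as few-shot demonstrations."""
    demos: List[Example] = []
    for example in trainset:
        if len(demos) >= max_demos:
            break
        prediction = program(example)      # run the pipeline on a training input
        if metric(example, prediction):    # keep the run only if the metric accepts it
            demos.append({**example, **prediction})
    return demos


def stub_program(example: Example) -> Example:
    # Stand-in for an LLM pipeline that only "knows" one disease.
    guess = "influenza" if "fever" in example["symptoms"] else "unknown"
    return {"final_diagnosis": guess}


def exact_match(example: Example, prediction: Example) -> bool:
    return example["final_diagnosis"].lower() == prediction["final_diagnosis"].lower()


train = [
    {"symptoms": "fever, cough, fatigue", "final_diagnosis": "influenza"},
    {"symptoms": "chest pain, shortness of breath", "final_diagnosis": "angina"},
]
demos = bootstrap_demos(stub_program, train, exact_match)
print(demos)  # [{'symptoms': 'fever, cough, fatigue', 'final_diagnosis': 'influenza'}]
```

Only the angina example fails here, so it contributes no demonstration. In DSPy, the surviving demonstrations end up in each predictor's prompt, which is what compiling changes.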
Now let's put this into practice and build the medical diagnosis application step by step.
Here’s the link to the full Colab notebook: https://colab.research.google.com/drive/1I1S8xC46NIadh2nAFdeZWA09Y-B8uSH1?usp=sharing
Let’s go section by section.
Setting Up the Environment
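To get started with DSPy, you’ll need to set up your environment. On Google Colab, the following cell clones the DSPy repository (to reuse its cache) and installs the dspy-ai package if it is missing.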
```python
%load_ext autoreload
%autoreload 2

import sys
import os

try:  # When on Google Colab, clone the DSPy repo so we can download its cache.
    import google.colab  # only used to detect Colab
    repo_path = 'dspy'
    !git -C $repo_path pull origin || git clone https://github.com/stanfordnlp/dspy $repo_path
except ImportError:  # Not on Colab: work from the current directory.
    repo_path = '.'

if repo_path not in sys.path:
    sys.path.append(repo_path)

# Set up the cache for this notebook
os.environ["DSP_NOTEBOOK_CACHEDIR"] = os.path.join(repo_path, 'cache')

import pkg_resources  # Install the package if it's not installed
if "dspy-ai" not in {pkg.key for pkg in pkg_resources.working_set}:
    !pip install -U pip
    !pip install dspy-ai
    !pip install openai~=0.28.1

import dspy
```
Configure DSPy:
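After installing DSPy, configure it for your project: set the API key for your language model provider and register the language model that every DSPy module should call. The model name below is only an example.
```python
import openai
import dspy

# Configure your OpenAI API key (openai~=0.28 reads it from this module attribute)
openai.api_key = 'your-api-key'
# Register a default LM with DSPy; all modules in this notebook will call it.
# dspy.OpenAI matches the older dspy-ai / openai 0.28 setup installed above;
# on DSPy 2.5+ the equivalent is dspy.LM("openai/gpt-3.5-turbo").
turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=500)
dspy.settings.configure(lm=turbo)
```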
Verify Installation:
Ensure that DSPy and all dependencies are installed correctly by running a simple test.
```python
# Verify the DSPy installation by printing the installed package version
from importlib.metadata import version
print(version("dspy-ai"))
```
Defining the Task with DSPy
In this section, we will define the task of medical diagnosis using DSPy. We will create signatures that specify the inputs and outputs for each step in the Chain-of-Thought (CoT) reasoning process, including intermediate diagnostic steps and the final diagnosis.
Creating Signatures
Signatures in DSPy define the structure of the tasks, specifying what inputs the language model will receive and what outputs it should produce. For our medical diagnosis example, we will create two main signatures: one for generating intermediate diagnostic steps and one for generating the final diagnosis.
```python
import dspy


class GenerateDiagnosticStep(dspy.Signature):
    """Generate an intermediate diagnostic step based on symptoms and medical history."""

    symptoms = dspy.InputField(desc="patient's symptoms")
    medical_history = dspy.InputField(desc="patient's medical history")
    diagnostic_step = dspy.OutputField(desc="an intermediate diagnostic step")


class GenerateFinalDiagnosis(dspy.Signature):
    """Generate the final diagnosis using all diagnostic steps."""

    symptoms = dspy.InputField(desc="patient's symptoms")
    medical_history = dspy.InputField(desc="patient's medical history")
    diagnostic_steps = dspy.InputField(desc="intermediate diagnostic steps generated so far")
    final_diagnosis = dspy.OutputField(desc="the final diagnosis of the patient")
```
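Conceptually, a signature is a declaration that the framework turns into prompt text. The plain-Python sketch below shows that idea with a hypothetical SimpleSignature class (not DSPy's API): input fields are filled with values, output fields are left open for the model, and the instructions go on top, much as the docstrings above do. DSPy's real prompt format is different and varies by version; what carries over is that the prompt is derived from the declaration, so an optimizer can rewrite or augment it without you editing strings.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FieldSpec:
    name: str
    desc: str = ""


@dataclass
class SimpleSignature:
    """Hypothetical stand-in for a signature: instructions plus named slots."""
    instructions: str
    inputs: List[FieldSpec]
    outputs: List[FieldSpec]

    def render(self, values: Dict[str, str]) -> str:
        """Turn the declaration plus concrete input values into a prompt string."""
        lines = [self.instructions, ""]
        for f in self.inputs:    # input slots are filled with the caller's values
            lines.append(f"{f.name}: {values[f.name]}")
        for f in self.outputs:   # output slots stay open for the model to complete
            hint = f" ({f.desc})" if f.desc else ""
            lines.append(f"{f.name}{hint}:")
        return "\n".join(lines)


step_signature = SimpleSignature(
    instructions="Generate an intermediate diagnostic step based on symptoms and medical history.",
    inputs=[FieldSpec("symptoms", "patient's symptoms"),
            FieldSpec("medical_history", "patient's medical history")],
    outputs=[FieldSpec("diagnostic_step", "an intermediate diagnostic step")],
)
prompt = step_signature.render({"symptoms": "fever, cough", "medical_history": "asthma"})
print(prompt)
```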
Developing Modules
With the signatures defined, we can now develop the modules that will use these signatures to perform the diagnostic steps and generate the final diagnosis.
MedicalDiagnosisQA Module
This module will use the GenerateDiagnosticStep and GenerateFinalDiagnosis signatures to implement the CoT reasoning process for medical diagnosis.
```python
import logging
import dspy

logging.basicConfig(level=logging.INFO)  # show the step-by-step logs below
logger = logging.getLogger(__name__)


class MedicalDiagnosisQA(dspy.Module):
    def __init__(self, num_steps=3):
        super().__init__()
        # Each predictor wraps a signature in a Chain-of-Thought prompt
        self.generate_step = dspy.ChainOfThought(GenerateDiagnosticStep)
        self.generate_final_diagnosis = dspy.ChainOfThought(GenerateFinalDiagnosis)
        self.num_steps = num_steps

    def forward(self, symptoms, medical_history):
        try:
            # Step 1: Generate intermediate diagnostic steps; each call sees the earlier steps
            diagnostic_steps = []
            current_context = medical_history
            for step_num in range(self.num_steps):
                step = self.generate_step(symptoms=symptoms, medical_history=current_context).diagnostic_step
                diagnostic_steps.append(step)
                current_context += " " + step
                logger.info(f"Step {step_num + 1}: {step}")

            # Step 2: Generate the final diagnosis. The steps are passed once via
            # diagnostic_steps, so medical_history receives only the original history.
            final_diagnosis = self.generate_final_diagnosis(
                symptoms=symptoms,
                medical_history=medical_history,
                diagnostic_steps=" ".join(diagnostic_steps),
            ).final_diagnosis
            logger.info(f"Final Diagnosis: {final_diagnosis}")

            return dspy.Prediction(final_diagnosis=final_diagnosis, diagnostic_steps=diagnostic_steps)
        except Exception as e:
            logger.error(f"An error occurred during diagnosis: {e}")
            raise
```
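To see the data flow of forward without spending API calls, the same loop can be run with plain functions standing in for the two ChainOfThought predictors. run_chain, fake_step, and fake_final are hypothetical names for this sketch; the stubs are deterministic, so you can check that each step receives the steps produced before it.

```python
from typing import Callable, List, Tuple


def run_chain(
    generate_step: Callable[[str, str], str],
    generate_final: Callable[[str, str, str], str],
    symptoms: str,
    medical_history: str,
    num_steps: int = 3,
) -> Tuple[str, List[str]]:
    """Mirror the step loop of MedicalDiagnosisQA.forward using plain callables."""
    steps: List[str] = []
    context = medical_history
    for _ in range(num_steps):
        step = generate_step(symptoms, context)  # each call sees all earlier steps
        steps.append(step)
        context += " " + step                     # grow the context for the next call
    # The final call gets the original history plus all steps joined into one string.
    final = generate_final(symptoms, medical_history, " ".join(steps))
    return final, steps


def fake_step(symptoms: str, context: str) -> str:
    # Deterministic stand-in for the LLM: number the step by counting earlier ones.
    return f"step-{context.count('step-') + 1}"


def fake_final(symptoms: str, history: str, steps: str) -> str:
    return f"diagnosis after: {steps}"


final, steps = run_chain(fake_step, fake_final, "persistent cough", "history of smoking")
print(steps)  # ['step-1', 'step-2', 'step-3']
print(final)  # diagnosis after: step-1 step-2 step-3
```

Because the context grows with every step, prompt length (and cost) grows with num_steps as well.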
Training and Validation
To let DSPy optimize the module, we need a small training set and a metric that decides whether a prediction is correct. Each example includes symptoms, medical history, an intermediate diagnostic step, and the final diagnosis.
```python
import dspy

# Example training data
train_data = [
    dspy.Example(symptoms="fever, cough, fatigue", medical_history="patient has a history of asthma", diagnostic_step="consider respiratory infections", final_diagnosis="influenza"),
    dspy.Example(symptoms="chest pain, shortness of breath", medical_history="patient has a history of high blood pressure", diagnostic_step="consider cardiovascular issues", final_diagnosis="angina"),
    dspy.Example(symptoms="headache, nausea, dizziness", medical_history="patient recently suffered a minor head injury", diagnostic_step="consider neurological evaluation", final_diagnosis="concussion"),
]
# Mark which fields are inputs; the rest (diagnostic_step, final_diagnosis) act as labels.
train_data = [ex.with_inputs("symptoms", "medical_history") for ex in train_data]


def validate_final_diagnosis(example, pred, trace=None):
    # Case-insensitive exact match on the final diagnosis
    return example.final_diagnosis.strip().lower() == pred.final_diagnosis.strip().lower()
```
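Exact string equality is strict for free-text LLM output: a prediction such as "Influenza (flu)." is scored wrong against the label "influenza". A hypothetical, more forgiving alternative (lenient_diagnosis_match and validate_final_diagnosis_lenient are names invented for this sketch) normalizes both strings and checks whether the label appears as a whole phrase in the prediction. It uses only the standard library and keeps the same (example, pred, trace=None) shape as validate_final_diagnosis.

```python
import re
from types import SimpleNamespace


def normalize(text: str) -> str:
    """Lowercase, replace punctuation with spaces, and collapse whitespace."""
    text = re.sub(r"[^a-z0-9\s]", " ", text.lower())
    return " ".join(text.split())


def lenient_diagnosis_match(expected: str, predicted: str) -> bool:
    """True if the normalized label appears as a whole phrase in the prediction."""
    return f" {normalize(expected)} " in f" {normalize(predicted)} "


def validate_final_diagnosis_lenient(example, pred, trace=None):
    # Same call shape as validate_final_diagnosis, so it can be passed as a metric.
    return lenient_diagnosis_match(example.final_diagnosis, pred.final_diagnosis)


example = SimpleNamespace(final_diagnosis="influenza")
print(validate_final_diagnosis_lenient(example, SimpleNamespace(final_diagnosis="Influenza (flu).")))  # True
print(validate_final_diagnosis_lenient(example, SimpleNamespace(final_diagnosis="Parainfluenza")))     # False
```

The trade-off: phrase matching also accepts answers like "not influenza", so a real evaluation would want a stricter output format or reviewed labels.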
Comparing Results Before and After Compilation
Finally, we will compare the module's output before and after compilation on a new case.
Before Compilation
```python
# Example new symptoms and medical history
new_symptoms = "persistent cough, weight loss, night sweats"
new_medical_history = "patient has a history of smoking"

# Get the prediction using the uncompiled module
uncompiled_medical_diagnosis_qa = MedicalDiagnosisQA()
uncompiled_prediction = uncompiled_medical_diagnosis_qa(symptoms=new_symptoms, medical_history=new_medical_history)

# Print the input and the prediction
print(f"Symptoms (Uncompiled): {new_symptoms}")
print(f"Medical History (Uncompiled): {new_medical_history}")
print(f"Predicted Final Diagnosis (Uncompiled): {uncompiled_prediction.final_diagnosis}")
print(f"Intermediate Diagnostic Steps (Uncompiled): {uncompiled_prediction.diagnostic_steps}")
```
After Compilation
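```python
from dspy.teleprompt import BootstrapFewShot

# Compile the module: BootstrapFewShot (one of DSPy's teleprompters) runs the program on the
# training examples and keeps runs that pass validate_final_diagnosis as few-shot demos.
teleprompter = BootstrapFewShot(metric=validate_final_diagnosis)
trainset = [ex.with_inputs("symptoms", "medical_history") for ex in train_data]  # mark input fields
compiled_medical_diagnosis_qa = teleprompter.compile(MedicalDiagnosisQA(), trainset=trainset)
# Get the prediction using the compiled module
compiled_prediction = compiled_medical_diagnosis_qa(symptoms=new_symptoms, medical_history=new_medical_history)
# Print the input and the prediction
print(f"Symptoms (Compiled): {new_symptoms}")
print(f"Medical History (Compiled): {new_medical_history}")
print(f"Predicted Final Diagnosis (Compiled): {compiled_prediction.final_diagnosis}")
print(f"Intermediate Diagnostic Steps (Compiled): {compiled_prediction.diagnostic_steps}")
```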
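A single case cannot show whether compilation helped; a fairer comparison scores both programs on the same held-out examples with the same metric. DSPy provides an evaluator for this (dspy.evaluate.Evaluate); the plain-Python sketch below, with a hypothetical average_score helper and stub data, just makes the arithmetic visible. In a real run you would pass uncompiled_medical_diagnosis_qa and compiled_medical_diagnosis_qa as program, together with a dev set of cases that are not in train_data.

```python
from types import SimpleNamespace
from typing import Any, Callable, Sequence


def average_score(program: Callable[..., Any], devset: Sequence[Any],
                  metric: Callable[[Any, Any], bool]) -> float:
    """Percentage of dev examples whose prediction passes the metric."""
    if not devset:
        return 0.0
    passed = 0
    for example in devset:
        pred = program(symptoms=example.symptoms, medical_history=example.medical_history)
        passed += bool(metric(example, pred))
    return 100.0 * passed / len(devset)


def exact_match(example, pred, trace=None):
    return example.final_diagnosis.strip().lower() == pred.final_diagnosis.strip().lower()


# Stub dev set and a stub "program" that always answers influenza.
devset = [
    SimpleNamespace(symptoms="fever, cough", medical_history="asthma", final_diagnosis="influenza"),
    SimpleNamespace(symptoms="chest pain", medical_history="hypertension", final_diagnosis="angina"),
]


def always_influenza(symptoms, medical_history):
    return SimpleNamespace(final_diagnosis="Influenza")


print(average_score(always_influenza, devset, exact_match))  # 50.0
```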
Summary and Analysis
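By following these steps, we let DSPy optimize the prompts of a multi-step medical reasoning pipeline (for example, by bootstrapping few-shot demonstrations from the training examples) instead of tuning them by hand. With only three training examples and a single test case, this walkthrough illustrates the workflow rather than measuring an accuracy gain; quantifying any improvement requires scoring both versions on a held-out set. Still, it shows how DSPy’s declarative approach extends beyond simple prompt-based methods to sophisticated, multi-step AI tasks.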
Thanks for reading! If you found this article useful, please leave a comment or a clap, and follow me to stay updated on my latest articles.