RC Boardroom Chronicles: AI Applications in the Legal Industry

#AI#NLP#legal-tech#LLaMA#business-case#fine-tuning

Recently, I took part in the Boardroom competition, a business case competition hosted by the University of Toronto's Rotman Commerce. This year's case was about a law firm that wanted to introduce AI to improve its efficiency. As soon as I saw the topic, I knew we had this in the bag, and after careful analysis I wrote a 2,400-word project draft on the first night. In brief: for a law firm, the obvious choice of AI is NLP. Think about it: lawyers don't need computer vision, OCR has been commonplace for ages (who doesn't know how to use scanning software now?), and numerical processing is even less relevant. With the direction settled, we needed to work out where NLP could be applied. We split the project into two parts, internal and external. The internal part focuses on optimizing the firm's own workflows, such as document organization and lawyer assistants; the external part targets clients, and the simplest piece there is an LLM for client consultation. The detailed plan was:

  1. Document Management System (Internal)
    • Automatic Classification and Tagging: An AI document management system using Natural Language Processing (NLP) and machine learning algorithms can automatically identify, classify, and tag various legal documents such as contracts, case files, and transcripts. The system can understand document content, extract key information, and organize documents into predefined categories. This not only reduces the time spent manually organizing files but also improves file management accuracy.
    • Quick Retrieval: Lawyers can quickly retrieve the files they need by keyword, tag, date, and so on, so the relevant material is at hand the moment it is needed.
    • File Template Generation: AI can generate standardized legal documents such as contracts, complaints, and mediation agreements, automatically filling in relevant information like party names and case numbers, reducing lawyers' workload in writing documents. This not only improves document generation efficiency but also ensures consistency and accuracy of document content.
  2. Smart Contracts & Customer Service (External)
    • Automated Customer Consultation: AI can handle preliminary consultations with prospective clients, providing standardized answers. This not only saves lawyers' time but also ensures every client receives a timely response. Through automated customer service, lawyers can focus on more complex legal issues rather than spending significant time on repetitive client consultations.
    • Smart Contract Generation: AI can generate standardized contracts based on customer needs, meeting clients' legal requirements. For clients who need standardized contracts, AI can quickly generate agreements, reducing lawyers' time investment in repetitive work.
  3. Compliance Review (External)
    • Automatic Review and Flagging: AI can automatically review contract terms, flag any non-compliant, unfair, or ambiguous clauses, and alert lawyers. This automated first pass helps lawyers quickly spot potential risks so they can conduct compliance reviews more efficiently.
    • Preliminary Screening: AI can quickly complete a preliminary screening of documents and pass the key issues to lawyers for further analysis. This automated screening not only saves time but also improves review accuracy.
    • Compliance Review: AI can quickly identify passages that don't comply with applicable laws and regulations, annotate the specific non-compliant passages, and prompt lawyers to revise them. This greatly reduces lawyers' document-review workload, especially when legal provisions are constantly updated and documents fall under the laws of different countries.
  4. Lawyer Assistant (Internal)
    • Schedule Management: AI can help lawyers arrange and manage schedules, reminding them of upcoming meetings and deadlines. This intelligent schedule management not only improves time management efficiency but also ensures lawyers don't miss important time points.
    • File Finding: AI can quickly find and organize required files, reducing lawyers' time on file management. This efficient file-finding function enables lawyers to invest more time in actual legal work rather than spending it on searching for files.

This is an excerpt from my initial draft; the final draft kept essentially the same framework. Almost every function in it relies on NLP, so we decided to build a demo, and I started with the LLM for serving external clients.

Building this model isn't actually complicated. In short, it means fine-tuning an existing large model on law-related datasets. We chose Llama 3 8B as the base model for two reasons: first, at the time it was arguably the strongest model in its parameter class; second, LLaMA has a larger and more active community than the alternatives, with more mature fine-tuning tooling. We ultimately chose the LLaMA-Factory project for fine-tuning.

The next step was choosing a dataset. My initial idea was to scrape Q&A from legal forums, but after looking around I found that on most US and Canadian legal forums users answer from personal experience: correctness aside, the language isn't rigorous enough. I ultimately considered three datasets from HuggingFace: dzunggg/legal-qa-v1, ibunescu/qa_legal_dataset_train, and coastalcph/lex_glue.

I used the entire first dataset because its Q&A is very professional: the questions are genuine (realistic situations commonly encountered in everyday life), and the answers are careful (with references and rigorous wording). The second dataset is larger but of lower quality; I suspect it was scraped from a legal forum. The third dataset has the best quality and is divided into several tasks: case_hold pairs excerpts from court decisions with the holding of the case they cite, scotus contains US Supreme Court opinions, and unfair_tos splits the terms of service of 50 online platforms into individual sentences and marks those containing unfair clauses. I didn't use it here because it has no everyday Q&A; everything is in formal document form. It is nonetheless one of the best NLP legal datasets available.

After downloading the dataset with the Hugging Face datasets library and converting it into the instruction/input/output format LLaMA-Factory reads, training can begin.

import json
import re
from datasets import load_dataset

# Load the dataset from the Hugging Face Hub
dataset = load_dataset("dzunggg/legal-qa-v1")

# Remove the "Q:" and "A:" prefixes some records start with
def remove_prefix(example):
    if example['question'].startswith('Q:'):
        example['question'] = example['question'][2:].strip()
    if example['answer'].startswith('A:'):
        example['answer'] = example['answer'][2:].strip()
    return example

# Replace line breaks with spaces and strip links
def clean_text(text):
    # Turn newline and carriage-return characters into spaces
    text = text.replace('\n', ' ').replace('\r', ' ')
    # Remove links (http://, https://, www.)
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    # Collapse the extra whitespace left behind
    return re.sub(r'\s+', ' ', text).strip()

# Strip the prefixes
dataset = dataset.map(remove_prefix)

# One shared instruction for every sample
def generate_instruction():
    return "Please provide detailed answers to the following legal questions."

# Convert a record to the Alpaca-style format (instruction/input/output)
def convert_to_new_format(example):
    question = clean_text(example['question'])
    answer = clean_text(example['answer'])
    return {
        "instruction": generate_instruction(),
        "input": question,
        "output": answer
    }

# Apply the conversion to the training split
new_dataset = {'train': [convert_to_new_format(example) for example in dataset['train']]}

# Save as JSON; UTF-8 because ensure_ascii=False writes non-ASCII characters as-is
with open('legal_qa_v1_train.json', 'w', encoding='utf-8') as f:
    json.dump(new_dataset['train'], f, indent=4, ensure_ascii=False)

print("Dataset has been saved to legal_qa_v1_train.json.")
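# Optional placeholder pass, adapted from LLaMA-Factory's Colab example, where it
# fills {{name}} and {{author}} in the identity dataset. The legal Q&A data normally
# has no such placeholders, so here it mostly leaves the file unchanged.
# Assumes legal_qa_v1_train.json has been copied into /content/LLaMA-Factory/data/.

import json

%cd /content/LLaMA-Factory/

NAME = "Llama-3"
AUTHOR = "LLaMA Factory"

with open("data/legal_qa_v1_train.json", "r", encoding="utf-8") as f:
  dataset = json.load(f)
for sample in dataset:
  sample["output"] = sample["output"].replace("{{name}}", NAME).replace("{{author}}", AUTHOR)
with open("data/legal_qa_v1_train.json", "w", encoding="utf-8") as f:
  json.dump(dataset, f, indent=2, ensure_ascii=False)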
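For the dataset= name in the training config to resolve, LLaMA-Factory needs the file registered in data/dataset_info.json. A minimal sketch of that registration follows. The helper name register_dataset is hypothetical, and the column mapping assumes the Alpaca-style instruction/input/output records produced above (which, as far as I know, is also LLaMA-Factory's default mapping). Editing dataset_info.json by hand works just as well.

```python
import json
import os


def register_dataset(data_dir, name, file_name):
    """Add (or overwrite) the entry `name` in <data_dir>/dataset_info.json.

    Hypothetical helper; assumes Alpaca-style records with
    instruction/input/output fields, mapped explicitly below.
    """
    info_path = os.path.join(data_dir, "dataset_info.json")
    info = {}
    if os.path.exists(info_path):
        with open(info_path, "r", encoding="utf-8") as f:
            info = json.load(f)  # keep the datasets that are already registered
    info[name] = {
        "file_name": file_name,
        "columns": {"prompt": "instruction", "query": "input", "response": "output"},
    }
    with open(info_path, "w", encoding="utf-8") as f:
        json.dump(info, f, indent=2, ensure_ascii=False)
    return info

# Usage, run from /content/LLaMA-Factory:
# register_dataset("data", "legal_qa_v1_train", "legal_qa_v1_train.json")
```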

Thanks to LLaMA-Factory, fine-tuning, the most complex and central step, turned out to be the simplest and shortest: take the example code, adjust the parameters, and run it with one click. Training time depends on dataset size; using only the first dataset above, 30 epochs took less than half an hour.

# Option 1: use the WebUI (the cell keeps running while the UI is open)
%cd /content/LLaMA-Factory/
!GRADIO_SHARE=1 llamafactory-cli webui

# Option 2: train from the command line with a JSON config
import json

args = dict(
  stage="sft",                        # do supervised fine-tuning
  do_train=True,
  model_name_or_path="mattshumer/Llama-3-8B-16K", # if Hugging Face is unreachable (e.g. in China), download the model and use its local path
  dataset="legal_qa_v1_train",        # dataset name as registered in data/dataset_info.json
  template="llama3",                  # use llama3 prompt template
  finetuning_type="lora",             # use LoRA adapters to save memory
  lora_target="all",                  # attach LoRA adapters to all linear layers
  output_dir="llama3_lora",           # the path to save LoRA adapters
  per_device_train_batch_size=8,      # the batch size; in my tests an NVIDIA L20 (48 GB) handled up to 48
  gradient_accumulation_steps=6,      # the gradient accumulation steps
  lr_scheduler_type="cosine",         # use cosine learning rate scheduler
  logging_steps=10,                   # log every 10 steps
  warmup_ratio=0.1,                   # use warmup scheduler
  save_steps=1000,                    # save checkpoint every 1000 steps
  learning_rate=1e-4,                 # the learning rate
  num_train_epochs=10.0,              # the epochs of training
  max_samples=500,                    # use 500 examples in each dataset
  max_grad_norm=1.0,                  # clip gradient norm to 1.0
  quantization_bit=8,                 # load the base model in 8-bit (QLoRA); 4 is also supported
  loraplus_lr_ratio=16.0,             # use LoRA+ algorithm with lambda=16.0
  use_unsloth=False,                  # set True to use UnslothAI's LoRA optimization (requires unsloth)
  fp16=True,                          # use float16 mixed precision training
  overwrite_output_dir=True,
)

with open("train_llama3.json", "w", encoding="utf-8") as f:
    json.dump(args, f, indent=2)

%cd /content/LLaMA-Factory/

!llamafactory-cli train train_llama3.json
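Because the number of optimizer updates drives warmup, logging, and checkpointing, it helps to estimate how many the config above actually produces. A rough sketch, assuming a single GPU and that max_samples=500 is the number of training examples actually used (the helper name estimate_optimizer_steps is hypothetical):

```python
import math


def estimate_optimizer_steps(num_samples, per_device_batch, grad_accum, epochs, num_devices=1):
    """Rough count of optimizer updates for a training run.

    Effective batch = per-device batch x accumulation steps x devices.
    Each epoch is rounded up; the Hugging Face Trainer may handle the last
    partial accumulation differently, so the result is approximate.
    """
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_samples / effective_batch)
    return effective_batch, steps_per_epoch, math.ceil(steps_per_epoch * epochs)


# Values from the config: max_samples=500, batch 8, accumulation 6, 10 epochs
eff, per_epoch, total = estimate_optimizer_steps(500, 8, 6, 10.0)
warmup = math.ceil(0.1 * total)  # warmup_ratio=0.1
print(eff, per_epoch, total, warmup)  # 48 11 110 11
```

With save_steps=1000 and only about 110 updates, no intermediate checkpoint would be written, so only the final adapter in llama3_lora remains. The Trainer's own count may come out slightly lower, so treat the figure as an estimate.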

The result is a safetensors file of about 80MB, which can be thought of as a "patch" for the large model. Running it through LLaMA-Factory's ChatModel gives a quick look at how the model behaves. If you're not satisfied, go back to the previous step to retrain or resume training; if it looks good, move on to the next step: merging.
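The "about 80MB" figure can be checked with a little arithmetic. LoRA adds two small matrices, A (r x in) and B (out x r), to each adapted linear layer, i.e. r * (in + out) extra parameters. The sketch below assumes rank r = 8 (the config does not set lora_rank, and 8 is LLaMA-Factory's default as far as I know), the seven linear projections per decoder block of Llama 3 8B that lora_target="all" adapts, and adapter weights stored as float32:

```python
def lora_param_count(rank, num_layers, shapes):
    """Trainable LoRA parameters: each adapted (in -> out) linear layer gets
    A (rank x in) and B (out x rank), i.e. rank * (in + out) parameters."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes)
    return per_layer * num_layers


# Llama-3-8B linear layers per decoder block as (in_features, out_features):
# hidden size 4096, MLP size 14336, 8 KV heads x 128 dims = 1024 for k/v
LLAMA3_8B_LINEAR = [
    (4096, 4096),   # q_proj
    (4096, 1024),   # k_proj
    (4096, 1024),   # v_proj
    (4096, 4096),   # o_proj
    (4096, 14336),  # gate_proj
    (4096, 14336),  # up_proj
    (14336, 4096),  # down_proj
]

params = lora_param_count(rank=8, num_layers=32, shapes=LLAMA3_8B_LINEAR)
size_mib = params * 4 / 2**20  # assuming float32 storage (4 bytes per weight)
print(params, size_mib)  # 20971520 80.0
```

That lands on exactly 80 MiB, consistent with the file size above; saving the adapter in 16-bit would halve it.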

Inference code:

from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc

%cd /content/LLaMA-Factory/

args = dict(
  model_name_or_path="mattshumer/Llama-3-8B-16K",
  adapter_name_or_path="llama3_lora",            # load the saved LoRA adapters, the address of the "patch" just generated
  template="llama3",                     # same as in training
  finetuning_type="lora",                # same as in training
  quantization_bit=8,                    # load the base model in 8-bit (4 also works)
  use_unsloth=True,                     # use UnslothAI's LoRA optimization for 2x faster generation
)
chat_model = ChatModel(args)

background_prompt = """
Custom background prompt
"""

messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
    query = input("\nUser: ")
    if query.strip() == "exit":
        break
    if query.strip() == "clear":
        messages = []
        torch_gc()
        print("History has been removed.")
        continue

    # Combine the user input with the background prompt
    combined_query = background_prompt + query
    messages.append({"role": "user", "content": combined_query})
    print("\n\nAssistant: ", end="", flush=True)

    response = ""
    for new_text in chat_model.stream_chat(messages, temperature=0.7, max_new_tokens=300): # example values, tune as needed
        print(new_text, end="", flush=True)
        response += new_text
    print()
    messages.append({"role": "assistant", "content": response})

torch_gc()

Merging, as the name suggests, combines the patch with the original large model to produce a new standalone model. This step also uses LLaMA-Factory (its export command) and takes about 2 minutes.

import json

%cd /content/LLaMA-Factory

args = dict(
  model_name_or_path="mattshumer/Llama-3-8B-16K", # must be the same (non-quantized) base model the adapter was trained on
  adapter_name_or_path="/content/LLaMA-Factory/llama3_lora",  # the generated "patch" (output_dir from training)
  template="llama3",                     # same as in training
  finetuning_type="lora",                # same as in training
  export_dir="llama3_lora_merged",              # the path to save the merged model, output directory
  export_size=2,                       # the file shard size (in GB) of the merged model
  export_device="cuda",                    # the device used in export, can be chosen from `cpu` and `cuda`
  #export_hub_model_id="your_id/your_model",         # the Hugging Face hub ID to upload model
)

json.dump(args, open("merge_llama3.json", "w", encoding="utf-8"), indent=2)

%cd /content/LLaMA-Factory/

!llamafactory-cli export merge_llama3.json
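Conceptually, merging folds each adapter's low-rank update into the weight it modifies: W' = W + (alpha / r) * B A, the standard LoRA formulation with the usual alpha/r scaling. The toy example below uses plain Python and made-up 2x2 numbers (not LLaMA-Factory's code) to check that the merged weight gives the same output as running the base path and the adapter path separately. It also shows why the merge needs the same base model the adapter was trained on: B A is a correction to that specific W.

```python
def matmul(X, Y):
    """Plain-Python matrix product of nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def add(X, Y, scale=1.0):
    """Elementwise X + scale * Y."""
    return [[x + scale * y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def merge_lora(W, A, B, alpha, rank):
    """Fold a LoRA update into the base weight: W' = W + (alpha / rank) * B @ A."""
    return add(W, matmul(B, A), scale=alpha / rank)


# Toy 2x2 layer with a rank-1 adapter
W = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight (out x in)
A = [[1.0, 2.0]]               # rank x in
B = [[0.5], [-1.0]]            # out x rank
x = [[3.0], [4.0]]             # input as a column vector

W_merged = merge_lora(W, A, B, alpha=2.0, rank=1)
y_merged = matmul(W_merged, x)                               # one matmul after merging
y_adapter = add(matmul(W, x), matmul(B, matmul(A, x)), 2.0)  # base path + scaled adapter path
print(W_merged)               # [[2.0, 2.0], [-2.0, -3.0]]
print(y_merged == y_adapter)  # True
```

After merging, inference needs no adapter code at all, which is why the merged model can be uploaded and used like any other checkpoint.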

At this point, the entire process is complete. Here's the code for uploading to HuggingFace:

import os
from huggingface_hub import HfApi

api = HfApi()
# Read the token from the HF_TOKEN environment variable; if it is unset (None),
# huggingface_hub falls back to the token saved by huggingface-cli login
token = os.environ.get("HF_TOKEN")

model_dir = "/content/LLaMA-Factory/llama3_lora_merged"
repo_id_base = "StevenChen16/llama3-8b-Lawyer"  # repository to upload to

# Create the repository (no error if it already exists)
api.create_repo(repo_id=repo_id_base, token=token, private=False, exist_ok=True)

# Upload every file in the merged-model folder (weight shards, config, tokenizer files)
for file_name in os.listdir(model_dir):
    file_path = os.path.join(model_dir, file_name)
    if os.path.isfile(file_path):
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_name,
            repo_id=repo_id_base,
            token=token,
        )

LLaMA-Factory repository address (the hero behind all this):

https://github.com/hiyouga/LLaMA-Factory.git

TODO:

Train a model for identifying unfair clauses and compliance review using the unfair_tos dataset from coastalcph/lex_glue.

Make this model support Chinese and fine-tune it with Chinese legal documents.


2024/06/06 Update

Completed the construction and deployment of compliance review.

I again used LLaMA-Factory, but this time with princeton-nlp/Llama-3-Instruct-8B-SimPO as the base model, which its authors claim outperforms GPT-4 on some chat benchmarks. It was trained on the unfair_tos portion of the lex_glue dataset, plus about 1,000 clause labels I added myself. I uploaded the dataset to HuggingFace as StevenChen16/unfair_tos and the trained model (adapter) as StevenChen16/llama3-8b-compliance-review-adapter. Note that this is only an adapter, so you still need to load the base model when using it. Below is example code using Gradio:

import re

import gradio as gr
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc

def split_into_sentences(text):
    # Split on whitespace after . ? or !, but not after patterns like "e.g." or "Mr."
    sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
    sentences = sentence_endings.split(text)
    return [sentence.strip() for sentence in sentences if sentence.strip()]

# Gradio shows a progress bar because this handler declares a gr.Progress() default argument
def process_paragraph(paragraph, progress=gr.Progress()):
    sentences = split_into_sentences(paragraph)
    results = []
    total_sentences = len(sentences)
    for i, sentence in enumerate(sentences):
        progress((i + 1) / total_sentences)
        # Classify each sentence on its own so earlier answers don't pile up in the context
        messages = [{"role": "user", "content": sentence}]
        sentence_response = ""
        for new_text in chat_model.stream_chat(messages, temperature=0.7, top_p=0.9, top_k=50, max_new_tokens=300):
            sentence_response += new_text  # keep inner spaces; strip once at the end
        # e.g. "Limitation of liability" -> "limitation_of_liability"
        category = sentence_response.strip().lower().replace(' ', '_')
        results.append((sentence, category))
        torch_gc()
    return results


args = dict(
  model_name_or_path="princeton-nlp/Llama-3-Instruct-8B-SimPO",  # base model the adapter was trained on
  adapter_name_or_path="StevenChen16/llama3-8b-compliance-review-adapter",  # load the saved LoRA adapter
  template="llama3",                 # same as used in training
  finetuning_type="lora",            # same as used in training
  quantization_bit=8,                # load the base model in 8-bit
  use_unsloth=True,                  # use UnslothAI's LoRA optimization for faster generation
)
chat_model = ChatModel(args)

# Map each label to a highlight color
label_to_color = {
    "fair": "green",
    "limitation_of_liability": "red",
    "unilateral_termination": "orange",
    "unilateral_change": "yellow",
    "content_removal": "purple",
    "contract_by_using": "blue",
    "choice_of_law": "cyan",
    "jurisdiction": "magenta",
    "arbitration": "brown",
}

with gr.Blocks() as demo:

    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
            btn = gr.Button("Process")
        with gr.Column():
            output = gr.HighlightedText(label="Processed Paragraph", color_map=label_to_color)

    btn.click(process_paragraph, inputs=input_text, outputs=[output])

demo.launch(share=True)
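The demo assumes the model replies with exactly one of the nine labels in label_to_color; with sampling at temperature 0.7 it may occasionally add punctuation or extra words. A small normalization step can guard against that. The sketch below is a hypothetical helper, not part of the published adapter; the label names follow the unfair_tos categories of lex_glue, converted to the same snake_case form the demo uses, and anything unrecognized becomes "unknown".

```python
# unfair_tos categories in lex_glue, plus "fair" for sentences with no unfair clause
UNFAIR_TOS_LABELS = [
    "Limitation of liability",
    "Unilateral termination",
    "Unilateral change",
    "Content removal",
    "Contract by using",
    "Choice of law",
    "Jurisdiction",
    "Arbitration",
]
UNFAIR_CATEGORIES = [name.lower().replace(" ", "_") for name in UNFAIR_TOS_LABELS]
CATEGORIES = {"fair", *UNFAIR_CATEGORIES}


def normalize_prediction(raw):
    """Map free-form model output onto a known category, or 'unknown'."""
    text = raw.strip().lower().replace(" ", "_").strip("._")
    if text in CATEGORIES:
        return text
    # Fallback: first unfair category mentioned anywhere, longest name first.
    # "fair" is excluded here because it is a substring of "unfair".
    for category in sorted(UNFAIR_CATEGORIES, key=lambda c: (-len(c), c)):
        if category in text:
            return category
    return "unknown"
```

It could replace the inline .strip().lower().replace(' ', '_') normalization in process_paragraph; adding an "unknown" entry to label_to_color would then give such sentences a fixed color.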