Initial commit: Agent JAE CLI - AI coding assistant (forked from pi-mono by badlogic)
Some checks are pending
CI / build-check-test (push) Waiting to run
Some checks are pending
CI / build-check-test (push) Waiting to run
This commit is contained in:
commit
2d26f44944
739 changed files with 204988 additions and 0 deletions
23
.gitattributes
vendored
Normal file
23
.gitattributes
vendored
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Default to LF for text files across the repo
|
||||
* text=auto eol=lf
|
||||
|
||||
# Windows scripts should keep CRLF
|
||||
*.bat text eol=crlf
|
||||
*.cmd text eol=crlf
|
||||
*.ps1 text eol=crlf
|
||||
|
||||
# Shell scripts should keep LF
|
||||
*.sh text eol=lf
|
||||
|
||||
# Common binary assets
|
||||
*.png binary
|
||||
*.jpg binary
|
||||
*.jpeg binary
|
||||
*.gif binary
|
||||
*.webp binary
|
||||
*.ico binary
|
||||
*.pdf binary
|
||||
*.zip binary
|
||||
*.gz binary
|
||||
*.woff binary
|
||||
*.woff2 binary
|
||||
136
.github/APPROVED_CONTRIBUTORS
vendored
Normal file
136
.github/APPROVED_CONTRIBUTORS
vendored
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
# GitHub handles of users approved to submit PRs
|
||||
# One handle per line (without @)
|
||||
# Add new contributors by commenting lgtm on their issue
|
||||
barapa
|
||||
alasano
|
||||
aadishv
|
||||
airtonix
|
||||
aliou
|
||||
aos
|
||||
austinm911
|
||||
banteg
|
||||
ben-vargas
|
||||
butelo
|
||||
can1357
|
||||
CarlosGtrz
|
||||
cau1k
|
||||
cmf
|
||||
crcatala
|
||||
Cursivez
|
||||
cv
|
||||
dannote
|
||||
default-anton
|
||||
dnouri
|
||||
DronNick
|
||||
enisdenjo
|
||||
ferologics
|
||||
fightbulc
|
||||
ghoulr
|
||||
gnattu
|
||||
HACKE-RC
|
||||
hewliyang
|
||||
hjanuschka
|
||||
iamd3vil
|
||||
jblwilliams
|
||||
joshp123
|
||||
jsinge97
|
||||
justram
|
||||
kaofelix
|
||||
kiliman
|
||||
kim0
|
||||
lockmeister
|
||||
LukeFost
|
||||
lukele
|
||||
m-box-mr
|
||||
marckrenn
|
||||
markusylisiurunen
|
||||
mcinteerj
|
||||
melihmucuk
|
||||
mitsuhiko
|
||||
mrexodia
|
||||
nathyong
|
||||
nickseelert
|
||||
nicobailon
|
||||
ninlds
|
||||
ogulcancelik
|
||||
patrick-kidger
|
||||
paulbettner
|
||||
Perlence
|
||||
pjtf93
|
||||
prateekmedia
|
||||
prathamdby
|
||||
ribelo
|
||||
richardgill
|
||||
robinwander
|
||||
ronyrus
|
||||
roshanasingh4
|
||||
scutifer
|
||||
skuridin
|
||||
steipete
|
||||
svkozak
|
||||
tallshort
|
||||
theBucky
|
||||
thomasmhr
|
||||
tiagoefreitas
|
||||
timolins
|
||||
tmustier
|
||||
tudoroancea
|
||||
unexge
|
||||
vaayne
|
||||
VaclavSynacek
|
||||
vsabavat
|
||||
w-winter
|
||||
Whamp
|
||||
WismutHansen
|
||||
XesGaDeus
|
||||
yevhen
|
||||
badlogictest
|
||||
terrorobe
|
||||
zedrdave
|
||||
mrud
|
||||
toorusr
|
||||
andresaraujo
|
||||
lightningRalf
|
||||
williballenthin
|
||||
masonc15
|
||||
4h9fbZ
|
||||
haoqixu
|
||||
Graffioh
|
||||
charles-cooper
|
||||
emanuelst
|
||||
juanibiapina
|
||||
liby
|
||||
pasky
|
||||
odysseus0
|
||||
giuseppeg
|
||||
michaelpersonal
|
||||
academo
|
||||
PriNova
|
||||
semtexzv
|
||||
jasonish
|
||||
markusn
|
||||
SamFold
|
||||
Soleone
|
||||
virtuald
|
||||
NateSmyth
|
||||
7Sageer
|
||||
MatthieuBizien
|
||||
sumeet
|
||||
marchellodev
|
||||
vedang
|
||||
lucemia
|
||||
mcollina
|
||||
lajarre
|
||||
smithbm2316
|
||||
drewburr
|
||||
gordonhwc
|
||||
deybhayden
|
||||
tintinweb
|
||||
asoules
|
||||
zhahaoyu
|
||||
in0vik
|
||||
jtac
|
||||
yzhg1983
|
||||
smcllns
|
||||
dmmulroy
|
||||
zmberber
|
||||
118
.github/APPROVED_CONTRIBUTORS.vacation
vendored
Normal file
118
.github/APPROVED_CONTRIBUTORS.vacation
vendored
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
# GitHub handles of users approved to submit PRs
|
||||
# One handle per line (without @)
|
||||
# Add new contributors by commenting lgtm on their issue
|
||||
aadishv
|
||||
airtonix
|
||||
aliou
|
||||
aos
|
||||
austinm911
|
||||
banteg
|
||||
ben-vargas
|
||||
butelo
|
||||
can1357
|
||||
CarlosGtrz
|
||||
cau1k
|
||||
cmf
|
||||
crcatala
|
||||
Cursivez
|
||||
cv
|
||||
dannote
|
||||
default-anton
|
||||
dnouri
|
||||
DronNick
|
||||
enisdenjo
|
||||
ferologics
|
||||
fightbulc
|
||||
ghoulr
|
||||
gnattu
|
||||
HACKE-RC
|
||||
hewliyang
|
||||
hjanuschka
|
||||
iamd3vil
|
||||
jblwilliams
|
||||
joshp123
|
||||
jsinge97
|
||||
justram
|
||||
kaofelix
|
||||
kiliman
|
||||
kim0
|
||||
lockmeister
|
||||
LukeFost
|
||||
lukele
|
||||
m-box-mr
|
||||
marckrenn
|
||||
markusylisiurunen
|
||||
mcinteerj
|
||||
melihmucuk
|
||||
mitsuhiko
|
||||
mrexodia
|
||||
nathyong
|
||||
nickseelert
|
||||
nicobailon
|
||||
ninlds
|
||||
ogulcancelik
|
||||
patrick-kidger
|
||||
paulbettner
|
||||
Perlence
|
||||
pjtf93
|
||||
prateekmedia
|
||||
prathamdby
|
||||
ribelo
|
||||
richardgill
|
||||
robinwander
|
||||
ronyrus
|
||||
roshanasingh4
|
||||
scutifer
|
||||
skuridin
|
||||
steipete
|
||||
svkozak
|
||||
tallshort
|
||||
theBucky
|
||||
thomasmhr
|
||||
tiagoefreitas
|
||||
timolins
|
||||
tmustier
|
||||
tudoroancea
|
||||
unexge
|
||||
vaayne
|
||||
VaclavSynacek
|
||||
vsabavat
|
||||
w-winter
|
||||
Whamp
|
||||
WismutHansen
|
||||
XesGaDeus
|
||||
yevhen
|
||||
badlogictest
|
||||
terrorobe
|
||||
zedrdave
|
||||
mrud
|
||||
toorusr
|
||||
andresaraujo
|
||||
lightningRalf
|
||||
williballenthin
|
||||
masonc15
|
||||
4h9fbZ
|
||||
haoqixu
|
||||
Graffioh
|
||||
charles-cooper
|
||||
emanuelst
|
||||
juanibiapina
|
||||
liby
|
||||
pasky
|
||||
odysseus0
|
||||
giuseppeg
|
||||
michaelpersonal
|
||||
academo
|
||||
PriNova
|
||||
semtexzv
|
||||
jasonish
|
||||
markusn
|
||||
SamFold
|
||||
Soleone
|
||||
virtuald
|
||||
NateSmyth
|
||||
7Sageer
|
||||
MatthieuBizien
|
||||
sumeet
|
||||
marchellodev
|
||||
vedang
|
||||
34
.github/ISSUE_TEMPLATE/bug.yml
vendored
Normal file
34
.github/ISSUE_TEMPLATE/bug.yml
vendored
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
name: Bug Report
|
||||
description: Report something that's broken
|
||||
labels: ["bug"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: Be specific. Include error messages if any.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: repro
|
||||
attributes:
|
||||
label: Steps to reproduce
|
||||
description: Minimal steps to trigger the bug.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: expected
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: Version
|
||||
description: e.g. 0.49.0
|
||||
validations:
|
||||
required: false
|
||||
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Questions
|
||||
url: https://discord.com/invite/3cU7Bz4UPx
|
||||
about: Ask questions on Discord instead of opening an issue
|
||||
34
.github/ISSUE_TEMPLATE/contribution.yml
vendored
Normal file
34
.github/ISSUE_TEMPLATE/contribution.yml
vendored
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
name: Contribution Proposal
|
||||
description: Propose a change or feature (required for new contributors before submitting a PR)
|
||||
labels: []
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
**Before you start:** Read [CONTRIBUTING.md](https://github.com/badlogic/pi-mono/blob/main/CONTRIBUTING.md).
|
||||
|
||||
Keep this short. If it doesn't fit on one screen, it's too long. Write in your own voice.
|
||||
|
||||
- type: textarea
|
||||
id: what
|
||||
attributes:
|
||||
label: What do you want to change?
|
||||
description: Be specific and concise.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: why
|
||||
attributes:
|
||||
label: Why?
|
||||
description: What problem does this solve?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: how
|
||||
attributes:
|
||||
label: How? (optional)
|
||||
description: Brief technical approach if you have one in mind.
|
||||
validations:
|
||||
required: false
|
||||
9
.github/oss-weekend.json
vendored
Normal file
9
.github/oss-weekend.json
vendored
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"active": true,
|
||||
"mode": "weekend",
|
||||
"startsAt": "2026-03-22T22:29:34.405Z",
|
||||
"startsAtText": "Sunday, March 22, 2026",
|
||||
"reopensOn": "2026-03-30",
|
||||
"reopensOnText": "Monday, March 30, 2026",
|
||||
"discordUrl": "https://discord.com/invite/3cU7Bz4UPx"
|
||||
}
|
||||
100
.github/workflows/approve-contributor.yml
vendored
Normal file
100
.github/workflows/approve-contributor.yml
vendored
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
name: Approve Contributor
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
approve:
|
||||
if: ${{ !github.event.issue.pull_request }}
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
issues: write
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Add contributor to approved list
|
||||
id: update
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
|
||||
const issueAuthor = context.payload.issue.user.login;
|
||||
const commenter = context.payload.comment.user.login;
|
||||
const commentBody = context.payload.comment.body || '';
|
||||
const approvedFile = '.github/APPROVED_CONTRIBUTORS';
|
||||
|
||||
if (!/^\s*lgtm\b/i.test(commentBody)) {
|
||||
console.log('Comment does not match lgtm');
|
||||
core.setOutput('status', 'skipped');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const { data: permissionLevel } = await github.rest.repos.getCollaboratorPermissionLevel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
username: commenter
|
||||
});
|
||||
|
||||
if (!['admin', 'write'].includes(permissionLevel.permission)) {
|
||||
console.log(`${commenter} does not have write access`);
|
||||
core.setOutput('status', 'skipped');
|
||||
return;
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(`${commenter} does not have collaborator access`);
|
||||
core.setOutput('status', 'skipped');
|
||||
return;
|
||||
}
|
||||
|
||||
let content = fs.readFileSync(approvedFile, 'utf8');
|
||||
const approvedList = content
|
||||
.split('\n')
|
||||
.map(line => line.trim().toLowerCase())
|
||||
.filter(line => line && !line.startsWith('#'));
|
||||
|
||||
if (approvedList.includes(issueAuthor.toLowerCase())) {
|
||||
console.log(`${issueAuthor} is already approved`);
|
||||
core.setOutput('status', 'already');
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `@${issueAuthor} is already in the approved contributors list.`
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
content = content.trimEnd() + '\n' + issueAuthor + '\n';
|
||||
fs.writeFileSync(approvedFile, content);
|
||||
|
||||
console.log(`Added ${issueAuthor} to approved contributors`);
|
||||
core.setOutput('status', 'added');
|
||||
|
||||
- name: Commit and push
|
||||
if: steps.update.outputs.status == 'added'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git add .github/APPROVED_CONTRIBUTORS
|
||||
git diff --staged --quiet || git commit -m "chore: approve contributor ${{ github.event.issue.user.login }}"
|
||||
git push
|
||||
|
||||
- name: Comment on issue
|
||||
if: steps.update.outputs.status == 'added'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const issueAuthor = context.payload.issue.user.login;
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `@${issueAuthor} has been added to the approved contributors list. You can now submit PRs. Thanks for contributing!`
|
||||
});
|
||||
79
.github/workflows/build-binaries.yml
vendored
Normal file
79
.github/workflows/build-binaries.yml
vendored
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
name: Build Binaries
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag to build (e.g., v0.12.0)'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.event.inputs.tag || github.ref_name }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
ref: ${{ env.RELEASE_TAG }}
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.1
|
||||
with:
|
||||
bun-version: 1.2.20
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
|
||||
with:
|
||||
node-version: '22'
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
|
||||
- name: Build binaries
|
||||
run: ./scripts/build-binaries.sh
|
||||
|
||||
- name: Extract changelog for this version
|
||||
id: changelog
|
||||
run: |
|
||||
VERSION="${RELEASE_TAG}"
|
||||
VERSION="${VERSION#v}" # Remove 'v' prefix
|
||||
|
||||
# Extract changelog section for this version
|
||||
cd packages/coding-agent
|
||||
awk "/^## \[${VERSION}\]/{flag=1; next} /^## \[/{flag=0} flag" CHANGELOG.md > /tmp/release-notes.md
|
||||
|
||||
# If empty, use a default message
|
||||
if [ ! -s /tmp/release-notes.md ]; then
|
||||
echo "Release ${VERSION}" > /tmp/release-notes.md
|
||||
fi
|
||||
|
||||
- name: Create GitHub Release and upload binaries
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
cd packages/coding-agent/binaries
|
||||
|
||||
# Create release with changelog notes (or update if exists)
|
||||
gh release create "${RELEASE_TAG}" \
|
||||
--title "${RELEASE_TAG}" \
|
||||
--notes-file /tmp/release-notes.md \
|
||||
pi-darwin-arm64.tar.gz \
|
||||
pi-darwin-x64.tar.gz \
|
||||
pi-linux-x64.tar.gz \
|
||||
pi-linux-arm64.tar.gz \
|
||||
pi-windows-x64.zip \
|
||||
2>/dev/null || \
|
||||
gh release upload "${RELEASE_TAG}" \
|
||||
pi-darwin-arm64.tar.gz \
|
||||
pi-darwin-x64.tar.gz \
|
||||
pi-linux-x64.tar.gz \
|
||||
pi-linux-arm64.tar.gz \
|
||||
pi-windows-x64.zip \
|
||||
--clobber
|
||||
42
.github/workflows/ci.yml
vendored
Normal file
42
.github/workflows/ci.yml
vendored
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
concurrency:
|
||||
group: ci-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-check-test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: npm
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libcairo2-dev libpango1.0-dev libjpeg-dev libgif-dev librsvg2-dev fd-find ripgrep
|
||||
sudo ln -s $(which fdfind) /usr/local/bin/fd
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Build
|
||||
run: npm run build
|
||||
|
||||
- name: Check
|
||||
run: npm run check
|
||||
|
||||
- name: Test
|
||||
run: npm test
|
||||
98
.github/workflows/oss-weekend-issues.yml
vendored
Normal file
98
.github/workflows/oss-weekend-issues.yml
vendored
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
name: OSS Weekend Issues
|
||||
|
||||
on:
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
close-issues-during-weekend:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
steps:
|
||||
- name: Close new issues during OSS weekend
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const issueAuthor = context.payload.issue.user.login;
|
||||
const defaultBranch = context.payload.repository.default_branch;
|
||||
|
||||
if (issueAuthor.endsWith('[bot]') || issueAuthor === 'dependabot[bot]') {
|
||||
console.log(`Skipping bot: ${issueAuthor}`);
|
||||
return;
|
||||
}
|
||||
|
||||
async function getPermission(username) {
|
||||
try {
|
||||
const { data: permissionLevel } = await github.rest.repos.getCollaboratorPermissionLevel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
username,
|
||||
});
|
||||
return permissionLevel.permission;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function getTextFile(path) {
|
||||
const { data: fileContent } = await github.rest.repos.getContent({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
path,
|
||||
ref: defaultBranch,
|
||||
});
|
||||
|
||||
if (!('content' in fileContent) || typeof fileContent.content !== 'string') {
|
||||
throw new Error(`Expected file content for ${path}`);
|
||||
}
|
||||
|
||||
return Buffer.from(fileContent.content, 'base64').toString('utf8');
|
||||
}
|
||||
|
||||
const permission = await getPermission(issueAuthor);
|
||||
if (['admin', 'maintain', 'write'].includes(permission)) {
|
||||
console.log(`${issueAuthor} is a collaborator with ${permission} access`);
|
||||
return;
|
||||
}
|
||||
|
||||
let weekendState;
|
||||
try {
|
||||
weekendState = JSON.parse(await getTextFile('.github/oss-weekend.json'));
|
||||
} catch (error) {
|
||||
if (error && typeof error === 'object' && 'status' in error && error.status === 404) {
|
||||
console.log('OSS weekend is not active');
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
if (!weekendState?.active) {
|
||||
console.log('OSS weekend is not active');
|
||||
return;
|
||||
}
|
||||
|
||||
const reopenDate = weekendState.reopensOnText || weekendState.reopensOn || 'after the weekend';
|
||||
const discordUrl = weekendState.discordUrl || 'https://discord.com/invite/3cU7Bz4UPx';
|
||||
const message = [
|
||||
`Hi @${issueAuthor}, thanks for opening an issue.`,
|
||||
'',
|
||||
`OSS weekend is active until ${reopenDate}, so new issues are being auto-closed for now.`,
|
||||
'',
|
||||
`Please reopen or submit this issue again after ${reopenDate}. For support, join [Discord](${discordUrl}).`,
|
||||
].join('\n');
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.issues.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
state: 'closed',
|
||||
});
|
||||
132
.github/workflows/pr-gate.yml
vendored
Normal file
132
.github/workflows/pr-gate.yml
vendored
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
name: PR Gate
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
check-contributor:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Check if contributor is approved
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const prAuthor = context.payload.pull_request.user.login;
|
||||
const defaultBranch = context.payload.repository.default_branch;
|
||||
|
||||
if (prAuthor.endsWith('[bot]') || prAuthor === 'dependabot[bot]') {
|
||||
console.log(`Skipping bot: ${prAuthor}`);
|
||||
return;
|
||||
}
|
||||
|
||||
async function getPermission(username) {
|
||||
try {
|
||||
const { data: permissionLevel } = await github.rest.repos.getCollaboratorPermissionLevel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
username,
|
||||
});
|
||||
return permissionLevel.permission;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function getTextFile(path) {
|
||||
const { data: fileContent } = await github.rest.repos.getContent({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
path,
|
||||
ref: defaultBranch,
|
||||
});
|
||||
|
||||
if (!('content' in fileContent) || typeof fileContent.content !== 'string') {
|
||||
throw new Error(`Expected file content for ${path}`);
|
||||
}
|
||||
|
||||
return Buffer.from(fileContent.content, 'base64').toString('utf8');
|
||||
}
|
||||
|
||||
async function closePullRequest(message) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.payload.pull_request.number,
|
||||
body: message,
|
||||
});
|
||||
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: context.payload.pull_request.number,
|
||||
state: 'closed',
|
||||
});
|
||||
}
|
||||
|
||||
const permission = await getPermission(prAuthor);
|
||||
if (['admin', 'maintain', 'write'].includes(permission)) {
|
||||
console.log(`${prAuthor} is a collaborator with ${permission} access`);
|
||||
return;
|
||||
}
|
||||
|
||||
const approvedContent = await getTextFile('.github/APPROVED_CONTRIBUTORS');
|
||||
const approvedList = approvedContent
|
||||
.split('\n')
|
||||
.map(line => line.trim().toLowerCase())
|
||||
.filter(line => line && !line.startsWith('#'));
|
||||
const isApprovedContributor = approvedList.includes(prAuthor.toLowerCase());
|
||||
|
||||
let weekendState = null;
|
||||
try {
|
||||
weekendState = JSON.parse(await getTextFile('.github/oss-weekend.json'));
|
||||
} catch (error) {
|
||||
if (!(error && typeof error === 'object' && 'status' in error && error.status === 404)) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
if (weekendState?.active && isApprovedContributor) {
|
||||
console.log(`${prAuthor} is approved, but OSS weekend is active`);
|
||||
|
||||
const reopenDate = weekendState.reopensOnText || weekendState.reopensOn || 'after the weekend';
|
||||
const discordUrl = weekendState.discordUrl || 'https://discord.com/invite/3cU7Bz4UPx';
|
||||
const message = [
|
||||
`Hi @${prAuthor}, thanks for the PR.`,
|
||||
'',
|
||||
`OSS weekend is active until ${reopenDate}, so external PRs are being paused for now.`,
|
||||
'',
|
||||
'You are already on the approved contributors list, so you can resubmit this PR after the weekend without reapproval.',
|
||||
'',
|
||||
`This PR will be closed automatically. For support, join [Discord](${discordUrl}).`,
|
||||
].join('\n');
|
||||
|
||||
await closePullRequest(message);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isApprovedContributor) {
|
||||
console.log(`${prAuthor} is in the approved contributors list`);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`${prAuthor} is not approved, closing PR`);
|
||||
|
||||
const message = [
|
||||
`Hi @${prAuthor}, thanks for your interest in contributing!`,
|
||||
'',
|
||||
'We ask new contributors to open an issue first before submitting a PR. This helps us discuss the approach and avoid wasted effort.',
|
||||
'',
|
||||
'**Next steps:**',
|
||||
'1. Open an issue describing what you want to change and why (keep it concise, write in your human voice, AI slop will be closed)',
|
||||
'2. Once a maintainer approves with `lgtm`, you\'ll be added to the approved contributors list',
|
||||
'3. Then you can submit your PR',
|
||||
'',
|
||||
`This PR will be closed automatically. See https://github.com/${context.repo.owner}/${context.repo.repo}/blob/${defaultBranch}/CONTRIBUTING.md for more details.`,
|
||||
].join('\n');
|
||||
|
||||
await closePullRequest(message);
|
||||
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
node_modules/
|
||||
dist/
|
||||
*.js.map
|
||||
.DS_Store
|
||||
*.tsbuildinfo
|
||||
.env
|
||||
40
.husky/pre-commit
Executable file
40
.husky/pre-commit
Executable file
|
|
@ -0,0 +1,40 @@
|
|||
#!/bin/sh
|
||||
|
||||
# Get list of staged files before running check
|
||||
STAGED_FILES=$(git diff --cached --name-only)
|
||||
|
||||
# Run the check script (formatting, linting, and type checking)
|
||||
echo "Running formatting, linting, and type checking..."
|
||||
npm run check
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "❌ Checks failed. Please fix the errors before committing."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
RUN_BROWSER_SMOKE=0
|
||||
for file in $STAGED_FILES; do
|
||||
case "$file" in
|
||||
packages/ai/*|packages/web-ui/*|package.json|package-lock.json)
|
||||
RUN_BROWSER_SMOKE=1
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ $RUN_BROWSER_SMOKE -eq 1 ]; then
|
||||
echo "Running browser smoke check..."
|
||||
npm run check:browser-smoke
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "❌ Browser smoke check failed."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Restage files that were previously staged and may have been modified by formatting
|
||||
for file in $STAGED_FILES; do
|
||||
if [ -f "$file" ]; then
|
||||
git add "$file"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "✅ All pre-commit checks passed!"
|
||||
218
.jae/extensions/diff.ts
Normal file
218
.jae/extensions/diff.ts
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
/**
|
||||
* Diff Extension
|
||||
*
|
||||
* /diff command shows modified/deleted/new files from git status and opens
|
||||
* the selected file in VS Code's diff view.
|
||||
*/
|
||||
|
||||
import type { ExtensionAPI } from "@jaeswift/jae-coding-agent";
|
||||
import { DynamicBorder } from "@jaeswift/jae-coding-agent";
|
||||
import { Container, Key, matchesKey, type SelectItem, SelectList, Text } from "@jaeswift/jae-tui";
|
||||
|
||||
interface FileInfo {
|
||||
status: string;
|
||||
statusLabel: string;
|
||||
file: string;
|
||||
}
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
pi.registerCommand("diff", {
|
||||
description: "Show git changes and open in VS Code diff view",
|
||||
handler: async (_args, ctx) => {
|
||||
if (!ctx.hasUI) {
|
||||
ctx.ui.notify("No UI available", "error");
|
||||
return;
|
||||
}
|
||||
|
||||
// Get changed files from git status
|
||||
const result = await pi.exec("git", ["status", "--porcelain"], { cwd: ctx.cwd });
|
||||
|
||||
if (result.code !== 0) {
|
||||
ctx.ui.notify(`git status failed: ${result.stderr}`, "error");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!result.stdout || !result.stdout.trim()) {
|
||||
ctx.ui.notify("No changes in working tree", "info");
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse git status output
|
||||
// Format: XY filename (where XY is two-letter status, then space, then filename)
|
||||
const lines = result.stdout.split("\n");
|
||||
const files: FileInfo[] = [];
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.length < 4) continue; // Need at least "XY f"
|
||||
|
||||
const status = line.slice(0, 2);
|
||||
const file = line.slice(2).trimStart();
|
||||
|
||||
// Translate status codes to short labels
|
||||
let statusLabel: string;
|
||||
if (status.includes("M")) statusLabel = "M";
|
||||
else if (status.includes("A")) statusLabel = "A";
|
||||
else if (status.includes("D")) statusLabel = "D";
|
||||
else if (status.includes("?")) statusLabel = "?";
|
||||
else if (status.includes("R")) statusLabel = "R";
|
||||
else if (status.includes("C")) statusLabel = "C";
|
||||
else statusLabel = status.trim() || "~";
|
||||
|
||||
files.push({ status: statusLabel, statusLabel, file });
|
||||
}
|
||||
|
||||
if (files.length === 0) {
|
||||
ctx.ui.notify("No changes found", "info");
|
||||
return;
|
||||
}
|
||||
|
||||
const WINDOWS_UNSAFE_CMD_CHARS_RE = /[&|<>^%\r\n]/;
|
||||
const quoteCmdArg = (value: string) => `"${value.replace(/"/g, '""')}"`;
|
||||
|
||||
const openWithCode = async (file: string) => {
|
||||
if (process.platform === "win32") {
|
||||
if (WINDOWS_UNSAFE_CMD_CHARS_RE.test(file)) {
|
||||
ctx.ui.notify(
|
||||
`Refusing to open ${file}: path contains Windows cmd metacharacters (& | < > ^ % or newline).`,
|
||||
"error",
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const commandLine = `code -g ${quoteCmdArg(file)}`;
|
||||
return pi.exec("cmd", ["/d", "/s", "/c", commandLine], { cwd: ctx.cwd });
|
||||
}
|
||||
return pi.exec("code", ["-g", file], { cwd: ctx.cwd });
|
||||
};
|
||||
|
||||
const openSelected = async (fileInfo: FileInfo): Promise<void> => {
|
||||
try {
|
||||
// Open in VS Code diff view.
|
||||
// For untracked files, git difftool won't work, so fall back to just opening the file.
|
||||
if (fileInfo.status === "?") {
|
||||
const openResult = await openWithCode(fileInfo.file);
|
||||
if (!openResult) return;
|
||||
if (openResult.code !== 0) {
|
||||
const openStderr = openResult.stderr.trim();
|
||||
ctx.ui.notify(
|
||||
`Failed to open ${fileInfo.file} (exit ${openResult.code})${openStderr ? `: ${openStderr}` : ""}`,
|
||||
"error",
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const diffResult = await pi.exec("git", ["difftool", "-y", "--tool=vscode", fileInfo.file], {
|
||||
cwd: ctx.cwd,
|
||||
});
|
||||
if (diffResult.code !== 0) {
|
||||
const diffStderr = diffResult.stderr.trim();
|
||||
ctx.ui.notify(
|
||||
`Failed to show diff with vscode for ${fileInfo.file} (exit ${diffResult.code})${diffStderr ? `: ${diffStderr}` : ""}`,
|
||||
"error",
|
||||
);
|
||||
ctx.ui.notify(
|
||||
"Troubleshooting: check git difftool config (e.g. `git config --get difftool.vscode.cmd`).",
|
||||
"info",
|
||||
);
|
||||
|
||||
const openResult = await openWithCode(fileInfo.file);
|
||||
if (!openResult) return;
|
||||
if (openResult.code !== 0) {
|
||||
const openStderr = openResult.stderr.trim();
|
||||
ctx.ui.notify(
|
||||
`Failed to open ${fileInfo.file} (exit ${openResult.code})${openStderr ? `: ${openStderr}` : ""}`,
|
||||
"error",
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
ctx.ui.notify(`Failed to open ${fileInfo.file}: ${message}`, "error");
|
||||
}
|
||||
};
|
||||
|
||||
// Show file picker with SelectList
|
||||
await ctx.ui.custom<void>((tui, theme, _kb, done) => {
|
||||
const container = new Container();
|
||||
|
||||
// Top border
|
||||
container.addChild(new DynamicBorder((s: string) => theme.fg("accent", s)));
|
||||
|
||||
// Title
|
||||
container.addChild(new Text(theme.fg("accent", theme.bold(" Select file to diff")), 0, 0));
|
||||
|
||||
// Build select items with colored status
|
||||
const items: SelectItem[] = files.map((f) => {
|
||||
let statusColor: string;
|
||||
switch (f.status) {
|
||||
case "M":
|
||||
statusColor = theme.fg("warning", f.status);
|
||||
break;
|
||||
case "A":
|
||||
statusColor = theme.fg("success", f.status);
|
||||
break;
|
||||
case "D":
|
||||
statusColor = theme.fg("error", f.status);
|
||||
break;
|
||||
case "?":
|
||||
statusColor = theme.fg("muted", f.status);
|
||||
break;
|
||||
default:
|
||||
statusColor = theme.fg("dim", f.status);
|
||||
}
|
||||
return {
|
||||
value: f,
|
||||
label: `${statusColor} ${f.file}`,
|
||||
};
|
||||
});
|
||||
|
||||
const visibleRows = Math.min(files.length, 15);
|
||||
let currentIndex = 0;
|
||||
|
||||
const selectList = new SelectList(items, visibleRows, {
|
||||
selectedPrefix: (t) => theme.fg("accent", t),
|
||||
selectedText: (t) => t, // Keep existing colors
|
||||
description: (t) => theme.fg("muted", t),
|
||||
scrollInfo: (t) => theme.fg("dim", t),
|
||||
noMatch: (t) => theme.fg("warning", t),
|
||||
});
|
||||
selectList.onSelect = (item) => {
|
||||
void openSelected(item.value as FileInfo);
|
||||
};
|
||||
selectList.onCancel = () => done();
|
||||
selectList.onSelectionChange = (item) => {
|
||||
currentIndex = items.indexOf(item);
|
||||
};
|
||||
container.addChild(selectList);
|
||||
|
||||
// Help text
|
||||
container.addChild(
|
||||
new Text(theme.fg("dim", " ↑↓ navigate • ←→ page • enter open • esc close"), 0, 0),
|
||||
);
|
||||
|
||||
// Bottom border
|
||||
container.addChild(new DynamicBorder((s: string) => theme.fg("accent", s)));
|
||||
|
||||
return {
|
||||
render: (w) => container.render(w),
|
||||
invalidate: () => container.invalidate(),
|
||||
handleInput: (data) => {
|
||||
// Add paging with left/right
|
||||
if (matchesKey(data, Key.left)) {
|
||||
// Page up - clamp to 0
|
||||
currentIndex = Math.max(0, currentIndex - visibleRows);
|
||||
selectList.setSelectedIndex(currentIndex);
|
||||
} else if (matchesKey(data, Key.right)) {
|
||||
// Page down - clamp to last
|
||||
currentIndex = Math.min(items.length - 1, currentIndex + visibleRows);
|
||||
selectList.setSelectedIndex(currentIndex);
|
||||
} else {
|
||||
selectList.handleInput(data);
|
||||
}
|
||||
tui.requestRender();
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
199
.jae/extensions/files.ts
Normal file
199
.jae/extensions/files.ts
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Files Extension
|
||||
*
|
||||
* /files command lists all files the model has read/written/edited in the active session branch,
|
||||
* coalesced by path and sorted newest first. Selecting a file opens it in VS Code.
|
||||
*/
|
||||
|
||||
import type { ExtensionAPI } from "@jaeswift/jae-coding-agent";
|
||||
import { DynamicBorder } from "@jaeswift/jae-coding-agent";
|
||||
import { Container, Key, matchesKey, type SelectItem, SelectList, Text } from "@jaeswift/jae-tui";
|
||||
|
||||
interface FileEntry {
|
||||
path: string;
|
||||
operations: Set<"read" | "write" | "edit">;
|
||||
lastTimestamp: number;
|
||||
}
|
||||
|
||||
type FileToolName = "read" | "write" | "edit";
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
pi.registerCommand("files", {
|
||||
description: "Show files read/written/edited in this session",
|
||||
handler: async (_args, ctx) => {
|
||||
if (!ctx.hasUI) {
|
||||
ctx.ui.notify("No UI available", "error");
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the current branch (path from leaf to root)
|
||||
const branch = ctx.sessionManager.getBranch();
|
||||
|
||||
// First pass: collect tool calls (id -> {path, name}) from assistant messages
|
||||
const toolCalls = new Map<string, { path: string; name: FileToolName; timestamp: number }>();
|
||||
|
||||
for (const entry of branch) {
|
||||
if (entry.type !== "message") continue;
|
||||
const msg = entry.message;
|
||||
|
||||
if (msg.role === "assistant" && Array.isArray(msg.content)) {
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "toolCall") {
|
||||
const name = block.name;
|
||||
if (name === "read" || name === "write" || name === "edit") {
|
||||
const path = block.arguments?.path;
|
||||
if (path && typeof path === "string") {
|
||||
toolCalls.set(block.id, { path, name, timestamp: msg.timestamp });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: match tool results to get the actual execution timestamp
|
||||
const fileMap = new Map<string, FileEntry>();
|
||||
|
||||
for (const entry of branch) {
|
||||
if (entry.type !== "message") continue;
|
||||
const msg = entry.message;
|
||||
|
||||
if (msg.role === "toolResult") {
|
||||
const toolCall = toolCalls.get(msg.toolCallId);
|
||||
if (!toolCall) continue;
|
||||
|
||||
const { path, name } = toolCall;
|
||||
const timestamp = msg.timestamp;
|
||||
|
||||
const existing = fileMap.get(path);
|
||||
if (existing) {
|
||||
existing.operations.add(name);
|
||||
if (timestamp > existing.lastTimestamp) {
|
||||
existing.lastTimestamp = timestamp;
|
||||
}
|
||||
} else {
|
||||
fileMap.set(path, {
|
||||
path,
|
||||
operations: new Set([name]),
|
||||
lastTimestamp: timestamp,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (fileMap.size === 0) {
|
||||
ctx.ui.notify("No files read/written/edited in this session", "info");
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort by most recent first
|
||||
const files = Array.from(fileMap.values()).sort((a, b) => b.lastTimestamp - a.lastTimestamp);
|
||||
|
||||
const WINDOWS_UNSAFE_CMD_CHARS_RE = /[&|<>^%\r\n]/;
|
||||
const quoteCmdArg = (value: string) => `"${value.replace(/"/g, '""')}"`;
|
||||
|
||||
const openWithCode = async (path: string) => {
|
||||
if (process.platform === "win32") {
|
||||
if (WINDOWS_UNSAFE_CMD_CHARS_RE.test(path)) {
|
||||
ctx.ui.notify(
|
||||
`Refusing to open ${path}: path contains Windows cmd metacharacters (& | < > ^ % or newline).`,
|
||||
"error",
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const commandLine = `code -g ${quoteCmdArg(path)}`;
|
||||
return pi.exec("cmd", ["/d", "/s", "/c", commandLine], { cwd: ctx.cwd });
|
||||
}
|
||||
return pi.exec("code", ["-g", path], { cwd: ctx.cwd });
|
||||
};
|
||||
|
||||
const openSelected = async (file: FileEntry): Promise<void> => {
|
||||
try {
|
||||
const openResult = await openWithCode(file.path);
|
||||
if (!openResult) return;
|
||||
if (openResult.code !== 0) {
|
||||
const openStderr = openResult.stderr.trim();
|
||||
ctx.ui.notify(
|
||||
`Failed to open ${file.path} (exit ${openResult.code})${openStderr ? `: ${openStderr}` : ""}`,
|
||||
"error",
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
ctx.ui.notify(`Failed to open ${file.path}: ${message}`, "error");
|
||||
}
|
||||
};
|
||||
|
||||
// Show file picker with SelectList
|
||||
await ctx.ui.custom<void>((tui, theme, _kb, done) => {
|
||||
const container = new Container();
|
||||
|
||||
// Top border
|
||||
container.addChild(new DynamicBorder((s: string) => theme.fg("accent", s)));
|
||||
|
||||
// Title
|
||||
container.addChild(new Text(theme.fg("accent", theme.bold(" Select file to open")), 0, 0));
|
||||
|
||||
// Build select items with colored operations
|
||||
const items: SelectItem[] = files.map((f) => {
|
||||
const ops: string[] = [];
|
||||
if (f.operations.has("read")) ops.push(theme.fg("muted", "R"));
|
||||
if (f.operations.has("write")) ops.push(theme.fg("success", "W"));
|
||||
if (f.operations.has("edit")) ops.push(theme.fg("warning", "E"));
|
||||
const opsLabel = ops.join("");
|
||||
return {
|
||||
value: f,
|
||||
label: `${opsLabel} ${f.path}`,
|
||||
};
|
||||
});
|
||||
|
||||
const visibleRows = Math.min(files.length, 15);
|
||||
let currentIndex = 0;
|
||||
|
||||
const selectList = new SelectList(items, visibleRows, {
|
||||
selectedPrefix: (t) => theme.fg("accent", t),
|
||||
selectedText: (t) => t, // Keep existing colors
|
||||
description: (t) => theme.fg("muted", t),
|
||||
scrollInfo: (t) => theme.fg("dim", t),
|
||||
noMatch: (t) => theme.fg("warning", t),
|
||||
});
|
||||
selectList.onSelect = (item) => {
|
||||
void openSelected(item.value as FileEntry);
|
||||
};
|
||||
selectList.onCancel = () => done();
|
||||
selectList.onSelectionChange = (item) => {
|
||||
currentIndex = items.indexOf(item);
|
||||
};
|
||||
container.addChild(selectList);
|
||||
|
||||
// Help text
|
||||
container.addChild(
|
||||
new Text(theme.fg("dim", " ↑↓ navigate • ←→ page • enter open • esc close"), 0, 0),
|
||||
);
|
||||
|
||||
// Bottom border
|
||||
container.addChild(new DynamicBorder((s: string) => theme.fg("accent", s)));
|
||||
|
||||
return {
|
||||
render: (w) => container.render(w),
|
||||
invalidate: () => container.invalidate(),
|
||||
handleInput: (data) => {
|
||||
// Add paging with left/right
|
||||
if (matchesKey(data, Key.left)) {
|
||||
// Page up - clamp to 0
|
||||
currentIndex = Math.max(0, currentIndex - visibleRows);
|
||||
selectList.setSelectedIndex(currentIndex);
|
||||
} else if (matchesKey(data, Key.right)) {
|
||||
// Page down - clamp to last
|
||||
currentIndex = Math.min(items.length - 1, currentIndex + visibleRows);
|
||||
selectList.setSelectedIndex(currentIndex);
|
||||
} else {
|
||||
selectList.handleInput(data);
|
||||
}
|
||||
tui.requestRender();
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
158
.jae/extensions/prompt-url-widget.ts
Normal file
158
.jae/extensions/prompt-url-widget.ts
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
import { DynamicBorder, type ExtensionAPI, type ExtensionContext } from "@jaeswift/jae-coding-agent";
|
||||
import { Container, Text } from "@jaeswift/jae-tui";
|
||||
|
||||
const PR_PROMPT_PATTERN = /^\s*You are given one or more GitHub PR URLs:\s*(\S+)/im;
|
||||
const ISSUE_PROMPT_PATTERN = /^\s*Analyze GitHub issue\(s\):\s*(\S+)/im;
|
||||
|
||||
type PromptMatch = {
|
||||
kind: "pr" | "issue";
|
||||
url: string;
|
||||
};
|
||||
|
||||
type GhMetadata = {
|
||||
title?: string;
|
||||
author?: {
|
||||
login?: string;
|
||||
name?: string | null;
|
||||
};
|
||||
};
|
||||
|
||||
function extractPromptMatch(prompt: string): PromptMatch | undefined {
|
||||
const prMatch = prompt.match(PR_PROMPT_PATTERN);
|
||||
if (prMatch?.[1]) {
|
||||
return { kind: "pr", url: prMatch[1].trim() };
|
||||
}
|
||||
|
||||
const issueMatch = prompt.match(ISSUE_PROMPT_PATTERN);
|
||||
if (issueMatch?.[1]) {
|
||||
return { kind: "issue", url: issueMatch[1].trim() };
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async function fetchGhMetadata(
|
||||
pi: ExtensionAPI,
|
||||
kind: PromptMatch["kind"],
|
||||
url: string,
|
||||
): Promise<GhMetadata | undefined> {
|
||||
const args =
|
||||
kind === "pr" ? ["pr", "view", url, "--json", "title,author"] : ["issue", "view", url, "--json", "title,author"];
|
||||
|
||||
try {
|
||||
const result = await pi.exec("gh", args);
|
||||
if (result.code !== 0 || !result.stdout) return undefined;
|
||||
return JSON.parse(result.stdout) as GhMetadata;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
function formatAuthor(author?: GhMetadata["author"]): string | undefined {
|
||||
if (!author) return undefined;
|
||||
const name = author.name?.trim();
|
||||
const login = author.login?.trim();
|
||||
if (name && login) return `${name} (@${login})`;
|
||||
if (login) return `@${login}`;
|
||||
if (name) return name;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export default function promptUrlWidgetExtension(pi: ExtensionAPI) {
|
||||
const setWidget = (ctx: ExtensionContext, match: PromptMatch, title?: string, authorText?: string) => {
|
||||
ctx.ui.setWidget("prompt-url", (_tui, thm) => {
|
||||
const titleText = title ? thm.fg("accent", title) : thm.fg("accent", match.url);
|
||||
const authorLine = authorText ? thm.fg("muted", authorText) : undefined;
|
||||
const urlLine = thm.fg("dim", match.url);
|
||||
|
||||
const lines = [titleText];
|
||||
if (authorLine) lines.push(authorLine);
|
||||
lines.push(urlLine);
|
||||
|
||||
const container = new Container();
|
||||
container.addChild(new DynamicBorder((s: string) => thm.fg("muted", s)));
|
||||
container.addChild(new Text(lines.join("\n"), 1, 0));
|
||||
return container;
|
||||
});
|
||||
};
|
||||
|
||||
const applySessionName = (ctx: ExtensionContext, match: PromptMatch, title?: string) => {
|
||||
const label = match.kind === "pr" ? "PR" : "Issue";
|
||||
const trimmedTitle = title?.trim();
|
||||
const fallbackName = `${label}: ${match.url}`;
|
||||
const desiredName = trimmedTitle ? `${label}: ${trimmedTitle} (${match.url})` : fallbackName;
|
||||
const currentName = pi.getSessionName()?.trim();
|
||||
if (!currentName) {
|
||||
pi.setSessionName(desiredName);
|
||||
return;
|
||||
}
|
||||
if (currentName === match.url || currentName === fallbackName) {
|
||||
pi.setSessionName(desiredName);
|
||||
}
|
||||
};
|
||||
|
||||
pi.on("before_agent_start", async (event, ctx) => {
|
||||
if (!ctx.hasUI) return;
|
||||
const match = extractPromptMatch(event.prompt);
|
||||
if (!match) {
|
||||
return;
|
||||
}
|
||||
|
||||
setWidget(ctx, match);
|
||||
applySessionName(ctx, match);
|
||||
void fetchGhMetadata(pi, match.kind, match.url).then((meta) => {
|
||||
const title = meta?.title?.trim();
|
||||
const authorText = formatAuthor(meta?.author);
|
||||
setWidget(ctx, match, title, authorText);
|
||||
applySessionName(ctx, match, title);
|
||||
});
|
||||
});
|
||||
|
||||
pi.on("session_switch", async (_event, ctx) => {
|
||||
rebuildFromSession(ctx);
|
||||
});
|
||||
|
||||
const getUserText = (content: string | { type: string; text?: string }[] | undefined): string => {
|
||||
if (!content) return "";
|
||||
if (typeof content === "string") return content;
|
||||
return (
|
||||
content
|
||||
.filter((block): block is { type: "text"; text: string } => block.type === "text")
|
||||
.map((block) => block.text)
|
||||
.join("\n") ?? ""
|
||||
);
|
||||
};
|
||||
|
||||
const rebuildFromSession = (ctx: ExtensionContext) => {
|
||||
if (!ctx.hasUI) return;
|
||||
|
||||
const entries = ctx.sessionManager.getEntries();
|
||||
const lastMatch = [...entries].reverse().find((entry) => {
|
||||
if (entry.type !== "message" || entry.message.role !== "user") return false;
|
||||
const text = getUserText(entry.message.content);
|
||||
return !!extractPromptMatch(text);
|
||||
});
|
||||
|
||||
const content =
|
||||
lastMatch?.type === "message" && lastMatch.message.role === "user" ? lastMatch.message.content : undefined;
|
||||
const text = getUserText(content);
|
||||
const match = text ? extractPromptMatch(text) : undefined;
|
||||
if (!match) {
|
||||
ctx.ui.setWidget("prompt-url", undefined);
|
||||
return;
|
||||
}
|
||||
|
||||
setWidget(ctx, match);
|
||||
applySessionName(ctx, match);
|
||||
void fetchGhMetadata(pi, match.kind, match.url).then((meta) => {
|
||||
const title = meta?.title?.trim();
|
||||
const authorText = formatAuthor(meta?.author);
|
||||
setWidget(ctx, match, title, authorText);
|
||||
applySessionName(ctx, match, title);
|
||||
});
|
||||
};
|
||||
|
||||
pi.on("session_start", async (_event, ctx) => {
|
||||
rebuildFromSession(ctx);
|
||||
});
|
||||
}
|
||||
24
.jae/extensions/redraws.ts
Normal file
24
.jae/extensions/redraws.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
/**
|
||||
* Redraws Extension
|
||||
*
|
||||
* Exposes /tui to show TUI redraw stats.
|
||||
*/
|
||||
|
||||
import type { ExtensionAPI } from "@jaeswift/jae-coding-agent";
|
||||
import { Text } from "@jaeswift/jae-tui";
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
pi.registerCommand("tui", {
|
||||
description: "Show TUI stats",
|
||||
handler: async (_args, ctx) => {
|
||||
if (!ctx.hasUI) return;
|
||||
let redraws = 0;
|
||||
await ctx.ui.custom<void>((tui, _theme, _keybindings, done) => {
|
||||
redraws = tui.fullRedraws;
|
||||
done(undefined);
|
||||
return new Text("", 0, 0);
|
||||
});
|
||||
ctx.ui.notify(`TUI full redraws: ${redraws}`, "info");
|
||||
},
|
||||
});
|
||||
}
|
||||
47
.jae/extensions/tps.ts
Normal file
47
.jae/extensions/tps.ts
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import type { AssistantMessage } from "@jaeswift/jae-ai";
|
||||
import type { ExtensionAPI } from "@jaeswift/jae-coding-agent";
|
||||
|
||||
function isAssistantMessage(message: unknown): message is AssistantMessage {
|
||||
if (!message || typeof message !== "object") return false;
|
||||
const role = (message as { role?: unknown }).role;
|
||||
return role === "assistant";
|
||||
}
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
let agentStartMs: number | null = null;
|
||||
|
||||
pi.on("agent_start", () => {
|
||||
agentStartMs = Date.now();
|
||||
});
|
||||
|
||||
pi.on("agent_end", (event, ctx) => {
|
||||
if (!ctx.hasUI) return;
|
||||
if (agentStartMs === null) return;
|
||||
|
||||
const elapsedMs = Date.now() - agentStartMs;
|
||||
agentStartMs = null;
|
||||
if (elapsedMs <= 0) return;
|
||||
|
||||
let input = 0;
|
||||
let output = 0;
|
||||
let cacheRead = 0;
|
||||
let cacheWrite = 0;
|
||||
let totalTokens = 0;
|
||||
|
||||
for (const message of event.messages) {
|
||||
if (!isAssistantMessage(message)) continue;
|
||||
input += message.usage.input || 0;
|
||||
output += message.usage.output || 0;
|
||||
cacheRead += message.usage.cacheRead || 0;
|
||||
cacheWrite += message.usage.cacheWrite || 0;
|
||||
totalTokens += message.usage.totalTokens || 0;
|
||||
}
|
||||
|
||||
if (output <= 0) return;
|
||||
|
||||
const elapsedSeconds = elapsedMs / 1000;
|
||||
const tokensPerSecond = output / elapsedSeconds;
|
||||
const message = `TPS ${tokensPerSecond.toFixed(1)} tok/s. out ${output.toLocaleString()}, in ${input.toLocaleString()}, cache r/w ${cacheRead.toLocaleString()}/${cacheWrite.toLocaleString()}, total ${totalTokens.toLocaleString()}, ${elapsedSeconds.toFixed(1)}s`;
|
||||
ctx.ui.notify(message, "info");
|
||||
});
|
||||
}
|
||||
2
.jae/git/.gitignore
vendored
Normal file
2
.jae/git/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
*
|
||||
!.gitignore
|
||||
2
.jae/npm/.gitignore
vendored
Normal file
2
.jae/npm/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
*
|
||||
!.gitignore
|
||||
54
.jae/prompts/cl.md
Normal file
54
.jae/prompts/cl.md
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
---
|
||||
description: Audit changelog entries before release
|
||||
---
|
||||
Audit changelog entries for all commits since the last release.
|
||||
|
||||
## Process
|
||||
|
||||
1. **Find the last release tag:**
|
||||
```bash
|
||||
git tag --sort=-version:refname | head -1
|
||||
```
|
||||
|
||||
2. **List all commits since that tag:**
|
||||
```bash
|
||||
git log <tag>..HEAD --oneline
|
||||
```
|
||||
|
||||
3. **Read each package's [Unreleased] section:**
|
||||
- packages/ai/CHANGELOG.md
|
||||
- packages/tui/CHANGELOG.md
|
||||
- packages/coding-agent/CHANGELOG.md
|
||||
|
||||
4. **For each commit, check:**
|
||||
- Skip: changelog updates, doc-only changes, release housekeeping
|
||||
- Skip: changes to generated model catalogs (for example `packages/ai/src/models.generated.ts`) unless accompanied by an intentional product-facing change in non-generated source/docs.
|
||||
- Determine which package(s) the commit affects (use `git show <hash> --stat`)
|
||||
- Verify a changelog entry exists in the affected package(s)
|
||||
- For external contributions (PRs), verify format: `Description ([#N](url) by [@user](url))`
|
||||
|
||||
5. **Cross-package duplication rule:**
|
||||
Changes in `ai`, `agent` or `tui` that affect end users should be duplicated to `coding-agent` changelog, since coding-agent is the user-facing package that depends on them.
|
||||
|
||||
6. **Add New Features section after changelog fixes:**
|
||||
- Insert a `### New Features` section at the start of `## [Unreleased]` in `packages/coding-agent/CHANGELOG.md`.
|
||||
- Propose the top new features to the user for confirmation before writing them.
|
||||
- Link to relevant docs and sections whenever possible.
|
||||
|
||||
7. **Report:**
|
||||
- List commits with missing entries
|
||||
- List entries that need cross-package duplication
|
||||
- Add any missing entries directly
|
||||
|
||||
## Changelog Format Reference
|
||||
|
||||
Sections (in order):
|
||||
- `### Breaking Changes` - API changes requiring migration
|
||||
- `### Added` - New features
|
||||
- `### Changed` - Changes to existing functionality
|
||||
- `### Fixed` - Bug fixes
|
||||
- `### Removed` - Removed features
|
||||
|
||||
Attribution:
|
||||
- Internal: `Fixed foo ([#123](https://github.com/jae/agent-jae-cli/issues/123))`
|
||||
- External: `Added bar ([#456](https://github.com/jae/agent-jae-cli/pull/456) by [@user](https://github.com/user))`
|
||||
24
.jae/prompts/is.md
Normal file
24
.jae/prompts/is.md
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
---
|
||||
description: Analyze GitHub issues (bugs or feature requests)
|
||||
---
|
||||
Analyze GitHub issue(s): $ARGUMENTS
|
||||
|
||||
For each issue:
|
||||
|
||||
1. Add the `inprogress` label to the issue via GitHub CLI before analysis starts. If adding the label fails, report that explicitly and continue.
|
||||
2. Read the issue in full, including all comments and linked issues/PRs.
|
||||
3. Do not trust analysis written in the issue. Independently verify behavior and derive your own analysis from the code and execution path.
|
||||
|
||||
4. **For bugs**:
|
||||
- Ignore any root cause analysis in the issue (likely wrong)
|
||||
- Read all related code files in full (no truncation)
|
||||
- Trace the code path and identify the actual root cause
|
||||
- Propose a fix
|
||||
|
||||
5. **For feature requests**:
|
||||
- Do not trust implementation proposals in the issue without verification
|
||||
- Read all related code files in full (no truncation)
|
||||
- Propose the most concise implementation approach
|
||||
- List affected files and changes needed
|
||||
|
||||
Do NOT implement unless explicitly asked. Analyze and propose only.
|
||||
40
.jae/prompts/pr.md
Normal file
40
.jae/prompts/pr.md
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
---
|
||||
description: Review PRs from URLs with structured issue and code analysis
|
||||
---
|
||||
You are given one or more GitHub PR URLs: $@
|
||||
|
||||
For each PR URL, do the following in order:
|
||||
1. Add the `inprogress` label to the PR via GitHub CLI before analysis starts. If adding the label fails, report that explicitly and continue.
|
||||
2. Read the PR page in full. Include description, all comments, all commits, and all changed files.
|
||||
3. Identify any linked issues referenced in the PR body, comments, commit messages, or cross links. Read each issue in full, including all comments.
|
||||
4. Analyze the PR diff. Read all relevant code files in full with no truncation from the current main branch and compare against the diff. Do not fetch PR file blobs unless a file is missing on main or the diff context is insufficient. Include related code paths that are not in the diff but are required to validate behavior.
|
||||
5. Check for a changelog entry in the relevant `packages/*/CHANGELOG.md` files. Report whether an entry exists. If missing, state that a changelog entry is required before merge and that you will add it if the user decides to merge. Follow the changelog format rules in AGENTS.md. Verify:
|
||||
- Entry uses correct section (`### Breaking Changes`, `### Added`, `### Fixed`, etc.)
|
||||
- External contributions include PR link and author: `Fixed foo ([#123](https://github.com/jae/agent-jae-cli/pull/123) by [@user](https://github.com/user))`
|
||||
- Breaking changes are in `### Breaking Changes`, not just `### Fixed`
|
||||
6. Check if packages/coding-agent/README.md, packages/coding-agent/docs/*.md, packages/coding-agent/examples/**/*.md require modification. This is usually the case when existing features have been changed, or new features have been added.
|
||||
7. Provide a structured review with these sections:
|
||||
- Good: solid choices or improvements
|
||||
- Bad: concrete issues, regressions, missing tests, or risks
|
||||
- Ugly: subtle or high impact problems
|
||||
8. Add Questions or Assumptions if anything is unclear.
|
||||
9. Add Change summary and Tests.
|
||||
|
||||
Output format per PR:
|
||||
PR: <url>
|
||||
Changelog:
|
||||
- ...
|
||||
Good:
|
||||
- ...
|
||||
Bad:
|
||||
- ...
|
||||
Ugly:
|
||||
- ...
|
||||
Questions or Assumptions:
|
||||
- ...
|
||||
Change summary:
|
||||
- ...
|
||||
Tests:
|
||||
- ...
|
||||
|
||||
If no issues are found, say so under Bad and Ugly.
|
||||
30
.jae/prompts/wr.md
Normal file
30
.jae/prompts/wr.md
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
---
|
||||
description: Finish the current task end-to-end with changelog, commit, and push
|
||||
---
|
||||
Wrap it.
|
||||
|
||||
Additional instructions: $ARGUMENTS
|
||||
|
||||
Determine context from the conversation history first.
|
||||
|
||||
Rules for context detection:
|
||||
- If the conversation already mentions a GitHub issue or PR, use that existing context.
|
||||
- If the work came from `/is` or `/pr`, assume the issue or PR context is already known from the conversation and from the analysis work already done.
|
||||
- If there is no GitHub issue or PR in the conversation history, treat this as non-GitHub work.
|
||||
|
||||
Unless I explicitly override something in this request, do the following in order:
|
||||
|
||||
1. Add or update the relevant package changelog entry under `## [Unreleased]` using the repo changelog rules.
|
||||
2. If this task is tied to a GitHub issue or PR and a final issue or PR comment has not already been posted in this session, draft it in my tone, preview it, and post exactly one final comment.
|
||||
3. Commit only files you changed in this session.
|
||||
4. If this task is tied to exactly one GitHub issue, include `closes #<issue>` in the commit message. If it is tied to multiple issues, stop and ask which one to use. If it is not tied to any issue, do not include `closes #` or `fixes #` in the commit message.
|
||||
5. Check the current git branch. If it is not `main`, stop and ask what to do. Do not push from another branch unless I explicitly say so.
|
||||
6. Push the current branch.
|
||||
|
||||
Constraints:
|
||||
- Never stage unrelated files.
|
||||
- Never use `git add .` or `git add -A`.
|
||||
- Run required checks before committing if code changed.
|
||||
- Do not open a PR unless I explicitly ask.
|
||||
- If this is not GitHub issue or PR work, do not post a GitHub comment.
|
||||
- If a final issue or PR comment was already posted in this session, do not post another one unless I explicitly ask.
|
||||
234
AGENTS.md
Normal file
234
AGENTS.md
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
# Development Rules
|
||||
|
||||
## First Message
|
||||
If the user did not give you a concrete task in their first message,
|
||||
read README.md, then ask which module(s) to work on. Based on the answer, read the relevant README.md files in parallel.
|
||||
- packages/ai/README.md
|
||||
- packages/tui/README.md
|
||||
- packages/agent/README.md
|
||||
- packages/coding-agent/README.md
|
||||
- packages/mom/README.md
|
||||
- packages/pods/README.md
|
||||
- packages/web-ui/README.md
|
||||
|
||||
## Code Quality
|
||||
- No `any` types unless absolutely necessary
|
||||
- Check node_modules for external API type definitions instead of guessing
|
||||
- **NEVER use inline imports** - no `await import("./foo.js")`, no `import("pkg").Type` in type positions, no dynamic imports for types. Always use standard top-level imports.
|
||||
- NEVER remove or downgrade code to fix type errors from outdated dependencies; upgrade the dependency instead
|
||||
- Always ask before removing functionality or code that appears to be intentional
|
||||
- Never hardcode key checks with, eg. `matchesKey(keyData, "ctrl+x")`. All keybindings must be configurable. Add default to matching object (`DEFAULT_EDITOR_KEYBINDINGS` or `DEFAULT_APP_KEYBINDINGS`)
|
||||
|
||||
## Commands
|
||||
- After code changes (not documentation changes): `npm run check` (get full output, no tail). Fix all errors, warnings, and infos before committing.
|
||||
- Note: `npm run check` does not run tests.
|
||||
- NEVER run: `npm run dev`, `npm run build`, `npm test`
|
||||
- Only run specific tests if user instructs: `npx tsx ../../node_modules/vitest/dist/cli.js --run test/specific.test.ts`
|
||||
- Run tests from the package root, not the repo root.
|
||||
- If you create or modify a test file, you MUST run that test file and iterate until it passes.
|
||||
- When writing tests, run them, identify issues in either the test or implementation, and iterate until fixed.
|
||||
- NEVER commit unless user asks
|
||||
|
||||
## GitHub Issues
|
||||
When reading issues:
|
||||
- Always read all comments on the issue
|
||||
- Use this command to get everything in one call:
|
||||
```bash
|
||||
gh issue view <number> --json title,body,comments,labels,state
|
||||
```
|
||||
|
||||
## OSS Weekend
|
||||
- If the user says `enable OSS weekend mode until X`, run `node scripts/oss-weekend.mjs --mode=close --end-date=YYYY-MM-DD --git` with the requested end date
|
||||
- If the user says `end OSS weekend mode`, run `node scripts/oss-weekend.mjs --mode=open --git`
|
||||
- The script updates `README.md`, `packages/coding-agent/README.md`, and `.github/oss-weekend.json`
|
||||
- With `--git`, the script stages only those OSS weekend files, commits them, and pushes them
|
||||
- During OSS weekend, `.github/workflows/oss-weekend-issues.yml` auto-closes new issues from non-maintainers, and `.github/workflows/pr-gate.yml` auto-closes PRs from approved non-maintainers with the weekend message
|
||||
|
||||
When creating issues:
|
||||
- Add `pkg:*` labels to indicate which package(s) the issue affects
|
||||
- Available labels: `pkg:agent`, `pkg:ai`, `pkg:coding-agent`, `pkg:mom`, `pkg:pods`, `pkg:tui`, `pkg:web-ui`
|
||||
- If an issue spans multiple packages, add all relevant labels
|
||||
|
||||
When posting issue/PR comments:
|
||||
- Write the full comment to a temp file and use `gh issue comment --body-file` or `gh pr comment --body-file`
|
||||
- Never pass multi-line markdown directly via `--body` in shell commands
|
||||
- Preview the exact comment text before posting
|
||||
- Post exactly one final comment unless the user explicitly asks for multiple comments
|
||||
- If a comment is malformed, delete it immediately, then post one corrected comment
|
||||
- Keep comments concise, technical, and in the user's tone
|
||||
|
||||
When closing issues via commit:
|
||||
- Include `fixes #<number>` or `closes #<number>` in the commit message
|
||||
- This automatically closes the issue when the commit is merged
|
||||
|
||||
## PR Workflow
|
||||
- Analyze PRs without pulling locally first
|
||||
- If the user approves: create a feature branch, pull PR, rebase on main, apply adjustments, commit, merge into main, push, close PR, and leave a comment in the user's tone
|
||||
- You never open PRs yourself. We work in feature branches until everything is according to the user's requirements, then merge into main, and push.
|
||||
|
||||
## Tools
|
||||
- GitHub CLI for issues/PRs
|
||||
- Add package labels to issues/PRs: pkg:agent, pkg:ai, pkg:coding-agent, pkg:mom, pkg:pods, pkg:tui, pkg:web-ui
|
||||
|
||||
## Testing pi Interactive Mode with tmux
|
||||
|
||||
To test pi's TUI in a controlled terminal environment:
|
||||
|
||||
```bash
|
||||
# Create tmux session with specific dimensions
|
||||
tmux new-session -d -s pi-test -x 80 -y 24
|
||||
|
||||
# Start pi from source
|
||||
tmux send-keys -t pi-test "cd /Users/badlogic/workspaces/jae-mono && ./jae-test.sh" Enter
|
||||
|
||||
# Wait for startup, then capture output
|
||||
sleep 3 && tmux capture-pane -t pi-test -p
|
||||
|
||||
# Send input
|
||||
tmux send-keys -t pi-test "your prompt here" Enter
|
||||
|
||||
# Send special keys
|
||||
tmux send-keys -t pi-test Escape
|
||||
tmux send-keys -t pi-test C-o # ctrl+o
|
||||
|
||||
# Cleanup
|
||||
tmux kill-session -t pi-test
|
||||
```
|
||||
|
||||
## Style
|
||||
- Keep answers short and concise
|
||||
- No emojis in commits, issues, PR comments, or code
|
||||
- No fluff or cheerful filler text
|
||||
- Technical prose only, be kind but direct (e.g., "Thanks @user" not "Thanks so much @user!")
|
||||
|
||||
## Changelog
|
||||
Location: `packages/*/CHANGELOG.md` (each package has its own)
|
||||
|
||||
### Format
|
||||
Use these sections under `## [Unreleased]`:
|
||||
- `### Breaking Changes` - API changes requiring migration
|
||||
- `### Added` - New features
|
||||
- `### Changed` - Changes to existing functionality
|
||||
- `### Fixed` - Bug fixes
|
||||
- `### Removed` - Removed features
|
||||
|
||||
### Rules
|
||||
- Before adding entries, read the full `[Unreleased]` section to see which subsections already exist
|
||||
- New entries ALWAYS go under `## [Unreleased]` section
|
||||
- Append to existing subsections (e.g., `### Fixed`), do not create duplicates
|
||||
- NEVER modify already-released version sections (e.g., `## [0.12.2]`)
|
||||
- Each version section is immutable once released
|
||||
|
||||
### Attribution
|
||||
- **Internal changes (from issues)**: `Fixed foo bar ([#123](https://github.com/jae/agent-jae-cli/issues/123))`
|
||||
- **External contributions**: `Added feature X ([#456](https://github.com/jae/agent-jae-cli/pull/456) by [@username](https://github.com/username))`
|
||||
|
||||
## Adding a New LLM Provider (packages/ai)
|
||||
|
||||
Adding a new provider requires changes across multiple files:
|
||||
|
||||
### 1. Core Types (`packages/ai/src/types.ts`)
|
||||
- Add API identifier to `Api` type union (e.g., `"bedrock-converse-stream"`)
|
||||
- Create options interface extending `StreamOptions`
|
||||
- Add mapping to `ApiOptionsMap`
|
||||
- Add provider name to `KnownProvider` type union
|
||||
|
||||
### 2. Provider Implementation (`packages/ai/src/providers/`)
|
||||
Create provider file exporting:
|
||||
- `stream<Provider>()` function returning `AssistantMessageEventStream`
|
||||
- `streamSimple<Provider>()` for `SimpleStreamOptions` mapping
|
||||
- Provider-specific options interface
|
||||
- Message/tool conversion functions
|
||||
- Response parsing emitting standardized events (`text`, `tool_call`, `thinking`, `usage`, `stop`)
|
||||
|
||||
### 3. Provider Exports and Lazy Registration
|
||||
- Add a package subpath export in `packages/ai/package.json` pointing at `./dist/providers/<provider>.js`
|
||||
- Add `export type` re-exports in `packages/ai/src/index.ts` for provider option types that should remain available from the root entry
|
||||
- Register the provider in `packages/ai/src/providers/register-builtins.ts` via lazy loader wrappers, do not statically import provider implementation modules there
|
||||
- Add credential detection in `packages/ai/src/env-api-keys.ts`
|
||||
|
||||
### 4. Model Generation (`packages/ai/scripts/generate-models.ts`)
|
||||
- Add logic to fetch/parse models from provider source
|
||||
- Map to standardized `Model` interface
|
||||
|
||||
### 5. Tests (`packages/ai/test/`)
|
||||
Add provider to: `stream.test.ts`, `tokens.test.ts`, `abort.test.ts`, `empty.test.ts`, `context-overflow.test.ts`, `image-limits.test.ts`, `unicode-surrogate.test.ts`, `tool-call-without-result.test.ts`, `image-tool-result.test.ts`, `total-tokens.test.ts`, `cross-provider-handoff.test.ts`.
|
||||
|
||||
For `cross-provider-handoff.test.ts`, add at least one provider/model pair. If the provider exposes multiple model families (for example GPT and Claude), add at least one pair per family.
|
||||
|
||||
For non-standard auth, create utility (e.g., `bedrock-utils.ts`) with credential detection.
|
||||
|
||||
### 6. Coding Agent (`packages/coding-agent/`)
|
||||
- `src/core/model-resolver.ts`: Add default model ID to `DEFAULT_MODELS`
|
||||
- `src/cli/args.ts`: Add env var documentation
|
||||
- `README.md`: Add provider setup instructions
|
||||
|
||||
### 7. Documentation
|
||||
- `packages/ai/README.md`: Add to providers table, document options/auth, add env vars
|
||||
- `packages/ai/CHANGELOG.md`: Add entry under `## [Unreleased]`
|
||||
|
||||
## Releasing
|
||||
|
||||
**Lockstep versioning**: All packages always share the same version number. Every release updates all packages together.
|
||||
|
||||
**Version semantics** (no major releases):
|
||||
- `patch`: Bug fixes and new features
|
||||
- `minor`: API breaking changes
|
||||
|
||||
### Steps
|
||||
|
||||
1. **Update CHANGELOGs**: Ensure all changes since last release are documented in the `[Unreleased]` section of each affected package's CHANGELOG.md
|
||||
|
||||
2. **Run release script**:
|
||||
```bash
|
||||
npm run release:patch # Fixes and additions
|
||||
npm run release:minor # API breaking changes
|
||||
```
|
||||
|
||||
The script handles: version bump, CHANGELOG finalization, commit, tag, publish, and adding new `[Unreleased]` sections.
|
||||
|
||||
## **CRITICAL** Tool Usage Rules **CRITICAL**
|
||||
- NEVER use sed/cat to read a file or a range of a file. Always use the read tool (use offset + limit for ranged reads).
|
||||
- You MUST read every file you modify in full before editing.
|
||||
|
||||
## **CRITICAL** Git Rules for Parallel Agents **CRITICAL**
|
||||
|
||||
Multiple agents may work on different files in the same worktree simultaneously. You MUST follow these rules:
|
||||
|
||||
### Committing
|
||||
- **ONLY commit files YOU changed in THIS session**
|
||||
- ALWAYS include `fixes #<number>` or `closes #<number>` in the commit message when there is a related issue or PR
|
||||
- NEVER use `git add -A` or `git add .` - these sweep up changes from other agents
|
||||
- ALWAYS use `git add <specific-file-paths>` listing only files you modified
|
||||
- Before committing, run `git status` and verify you are only staging YOUR files
|
||||
- Track which files you created/modified/deleted during the session
|
||||
|
||||
### Forbidden Git Operations
|
||||
These commands can destroy other agents' work:
|
||||
- `git reset --hard` - destroys uncommitted changes
|
||||
- `git checkout .` - destroys uncommitted changes
|
||||
- `git clean -fd` - deletes untracked files
|
||||
- `git stash` - stashes ALL changes including other agents' work
|
||||
- `git add -A` / `git add .` - stages other agents' uncommitted work
|
||||
- `git commit --no-verify` - bypasses required checks and is never allowed
|
||||
|
||||
### Safe Workflow
|
||||
```bash
|
||||
# 1. Check status first
|
||||
git status
|
||||
|
||||
# 2. Add ONLY your specific files
|
||||
git add packages/ai/src/providers/transform-messages.ts
|
||||
git add packages/ai/CHANGELOG.md
|
||||
|
||||
# 3. Commit
|
||||
git commit -m "fix(ai): description"
|
||||
|
||||
# 4. Push (pull --rebase if needed, but NEVER reset/checkout)
|
||||
git pull --rebase && git push
|
||||
```
|
||||
|
||||
### If Rebase Conflicts Occur
|
||||
- Resolve conflicts in YOUR files only
|
||||
- If conflict is in a file you didn't modify, abort and ask the user
|
||||
- NEVER force push
|
||||
42
CONTRIBUTING.md
Normal file
42
CONTRIBUTING.md
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Contributing to pi
|
||||
|
||||
Thanks for wanting to contribute! This guide exists to save both of us time.
|
||||
|
||||
## The One Rule
|
||||
|
||||
**You must understand your code.** If you can't explain what your changes do and how they interact with the rest of the system, your PR will be closed.
|
||||
|
||||
Using AI to write code is fine. You can gain understanding by interrogating an agent with access to the codebase until you grasp all edge cases and effects of your changes. What's not fine is submitting agent-generated slop without that understanding.
|
||||
|
||||
If you use an agent, run it from the `jae-mono` root directory so it picks up `AGENTS.md` automatically. Your agent must follow the rules and guidelines in that file.
|
||||
|
||||
## First-Time Contributors
|
||||
|
||||
We use an approval gate for new contributors:
|
||||
|
||||
1. Open an issue describing what you want to change and why
|
||||
2. Keep it concise (if it doesn't fit on one screen, it's too long)
|
||||
3. Write in your own voice, at least for the intro
|
||||
4. A maintainer will comment `lgtm` if approved
|
||||
5. Once approved, you can submit PRs
|
||||
|
||||
This exists because AI makes it trivial to generate plausible-looking but low-quality contributions. The issue step lets us filter early.
|
||||
|
||||
## Before Submitting a PR
|
||||
|
||||
```bash
|
||||
npm run check # must pass with no errors
|
||||
./test.sh # must pass
|
||||
```
|
||||
|
||||
Do not edit `CHANGELOG.md`. Changelog entries are added by maintainers.
|
||||
|
||||
If you're adding a new provider to `packages/ai`, see `AGENTS.md` for required tests.
|
||||
|
||||
## Philosophy
|
||||
|
||||
pi's core is minimal. If your feature doesn't belong in the core, it should be an extension. PRs that bloat the core will likely be rejected.
|
||||
|
||||
## Questions?
|
||||
|
||||
Open an issue or ask on [Discord](https://discord.com/invite/nKXTsAcmbT).
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025 Mario Zechner
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
62
README.md
Normal file
62
README.md
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
<!-- OSS_WEEKEND_START -->
|
||||
# 🏖️ OSS Weekend
|
||||
|
||||
**Issue tracker reopens Monday, March 30, 2026.**
|
||||
|
||||
OSS weekend runs Sunday, March 22, 2026 through Monday, March 30, 2026. New issues are auto-closed during this time. For support, join [Discord](https://discord.com/invite/3cU7Bz4UPx).
|
||||
<!-- OSS_WEEKEND_END -->
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
<a href="https://shittycodingagent.ai">
|
||||
<img src="https://shittycodingagent.ai/logo.svg" alt="pi logo" width="128">
|
||||
</a>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://discord.com/invite/3cU7Bz4UPx"><img alt="Discord" src="https://img.shields.io/badge/discord-community-5865F2?style=flat-square&logo=discord&logoColor=white" /></a>
|
||||
<a href="https://github.com/jae/agent-jae-cli/actions/workflows/ci.yml"><img alt="Build status" src="https://img.shields.io/github/actions/workflow/status/jae/agent-jae-cli/ci.yml?style=flat-square&branch=main" /></a>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://pi.dev">pi.dev</a> domain graciously donated by
|
||||
<br /><br />
|
||||
<a href="https://exe.dev"><img src="packages/coding-agent/docs/images/exy.png" alt="Exy mascot" width="48" /><br />exe.dev</a>
|
||||
</p>
|
||||
|
||||
# Pi Monorepo
|
||||
|
||||
> **Looking for the pi coding agent?** See **[packages/coding-agent](packages/coding-agent)** for installation and usage.
|
||||
|
||||
Tools for building AI agents and managing LLM deployments.
|
||||
|
||||
## Packages
|
||||
|
||||
| Package | Description |
|
||||
|---------|-------------|
|
||||
| **[@jaeswift/jae-ai](packages/ai)** | Unified multi-provider LLM API (OpenAI, Anthropic, Google, etc.) |
|
||||
| **[@jaeswift/jae-agent-core](packages/agent)** | Agent runtime with tool calling and state management |
|
||||
| **[@jaeswift/jae-coding-agent](packages/coding-agent)** | Interactive coding agent CLI |
|
||||
| **[@jaeswift/jae-mom](packages/mom)** | Slack bot that delegates messages to the pi coding agent |
|
||||
| **[@jaeswift/jae-tui](packages/tui)** | Terminal UI library with differential rendering |
|
||||
| **[@jaeswift/jae-web-ui](packages/web-ui)** | Web components for AI chat interfaces |
|
||||
| **[@jaeswift/jae-pods](packages/pods)** | CLI for managing vLLM deployments on GPU pods |
|
||||
|
||||
## Contributing
|
||||
|
||||
See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution guidelines and [AGENTS.md](AGENTS.md) for project-specific rules (for both humans and agents).
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
npm install # Install all dependencies
|
||||
npm run build # Build all packages
|
||||
npm run check # Lint, format, and type check
|
||||
./test.sh # Run tests (skips LLM-dependent tests without API keys)
|
||||
./jae-test.sh # Run pi from sources (must be run from repo root)
|
||||
```
|
||||
|
||||
> **Note:** `npm run check` requires `npm run build` to be run first. The web-ui package uses `tsc` which needs compiled `.d.ts` files from dependencies.
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
41
biome.json
Normal file
41
biome.json
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
{
|
||||
"$schema": "https://biomejs.dev/schemas/2.3.5/schema.json",
|
||||
"linter": {
|
||||
"enabled": true,
|
||||
"rules": {
|
||||
"recommended": true,
|
||||
"style": {
|
||||
"noNonNullAssertion": "off",
|
||||
"useConst": "error",
|
||||
"useNodejsImportProtocol": "off"
|
||||
},
|
||||
"suspicious": {
|
||||
"noExplicitAny": "off",
|
||||
"noControlCharactersInRegex": "off",
|
||||
"noEmptyInterface": "off"
|
||||
}
|
||||
}
|
||||
},
|
||||
"formatter": {
|
||||
"enabled": true,
|
||||
"formatWithErrors": false,
|
||||
"indentStyle": "tab",
|
||||
"indentWidth": 3,
|
||||
"lineWidth": 120
|
||||
},
|
||||
"files": {
|
||||
"includes": [
|
||||
"packages/*/src/**/*.ts",
|
||||
"packages/*/test/**/*.ts",
|
||||
"packages/coding-agent/examples/**/*.ts",
|
||||
"packages/web-ui/src/**/*.ts",
|
||||
"packages/web-ui/example/**/*.ts",
|
||||
"!**/node_modules/**/*",
|
||||
"!**/test-sessions.ts",
|
||||
"!**/models.generated.ts",
|
||||
"!packages/web-ui/src/app.css",
|
||||
"!packages/mom/data/**/*",
|
||||
"!!**/node_modules"
|
||||
]
|
||||
}
|
||||
}
|
||||
12
jae-mono.code-workspace
Normal file
12
jae-mono.code-workspace
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"folders": [
|
||||
{
|
||||
"name": "pi-mono",
|
||||
"path": "."
|
||||
},
|
||||
{
|
||||
"path": "../../moms"
|
||||
}
|
||||
],
|
||||
"settings": {}
|
||||
}
|
||||
56
jae-test.sh
Executable file
56
jae-test.sh
Executable file
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
# Check for --no-env flag
|
||||
NO_ENV=false
|
||||
ARGS=()
|
||||
for arg in "$@"; do
|
||||
if [[ "$arg" == "--no-env" ]]; then
|
||||
NO_ENV=true
|
||||
else
|
||||
ARGS+=("$arg")
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ "$NO_ENV" == "true" ]]; then
|
||||
# Unset API keys (see packages/ai/src/env-api-keys.ts)
|
||||
unset ANTHROPIC_API_KEY
|
||||
unset ANTHROPIC_OAUTH_TOKEN
|
||||
unset OPENAI_API_KEY
|
||||
unset GEMINI_API_KEY
|
||||
unset GROQ_API_KEY
|
||||
unset CEREBRAS_API_KEY
|
||||
unset XAI_API_KEY
|
||||
unset OPENROUTER_API_KEY
|
||||
unset ZAI_API_KEY
|
||||
unset MISTRAL_API_KEY
|
||||
unset MINIMAX_API_KEY
|
||||
unset MINIMAX_CN_API_KEY
|
||||
unset AI_GATEWAY_API_KEY
|
||||
unset OPENCODE_API_KEY
|
||||
unset COPILOT_GITHUB_TOKEN
|
||||
unset GH_TOKEN
|
||||
unset GITHUB_TOKEN
|
||||
unset GOOGLE_APPLICATION_CREDENTIALS
|
||||
unset GOOGLE_CLOUD_PROJECT
|
||||
unset GCLOUD_PROJECT
|
||||
unset GOOGLE_CLOUD_LOCATION
|
||||
unset AWS_PROFILE
|
||||
unset AWS_ACCESS_KEY_ID
|
||||
unset AWS_SECRET_ACCESS_KEY
|
||||
unset AWS_SESSION_TOKEN
|
||||
unset AWS_REGION
|
||||
unset AWS_DEFAULT_REGION
|
||||
unset AWS_BEARER_TOKEN_BEDROCK
|
||||
unset AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
|
||||
unset AWS_CONTAINER_CREDENTIALS_FULL_URI
|
||||
unset AWS_WEB_IDENTITY_TOKEN_FILE
|
||||
unset AZURE_OPENAI_API_KEY
|
||||
unset AZURE_OPENAI_BASE_URL
|
||||
unset AZURE_OPENAI_RESOURCE_NAME
|
||||
echo "Running without API keys..."
|
||||
fi
|
||||
|
||||
npx tsx "$SCRIPT_DIR/packages/coding-agent/src/cli.ts" ${ARGS[@]+"${ARGS[@]}"}
|
||||
8824
package-lock.json
generated
Normal file
8824
package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load diff
61
package.json
Normal file
61
package.json
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
{
|
||||
"name": "jae-monorepo",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"workspaces": [
|
||||
"packages/*",
|
||||
"packages/web-ui/example",
|
||||
"packages/coding-agent/examples/extensions/with-deps",
|
||||
"packages/coding-agent/examples/extensions/custom-provider-anthropic",
|
||||
"packages/coding-agent/examples/extensions/custom-provider-gitlab-duo",
|
||||
"packages/coding-agent/examples/extensions/custom-provider-qwen-cli"
|
||||
],
|
||||
"scripts": {
|
||||
"clean": "npm run clean --workspaces",
|
||||
"build": "cd packages/tui && npm run build && cd ../ai && npm run build && cd ../agent && npm run build && cd ../coding-agent && npm run build && cd ../mom && npm run build && cd ../web-ui && npm run build && cd ../pods && npm run build",
|
||||
"dev": "concurrently --names \"ai,agent,coding-agent,mom,web-ui,tui\" --prefix-colors \"cyan,yellow,red,white,green,magenta\" \"cd packages/ai && npm run dev\" \"cd packages/agent && npm run dev\" \"cd packages/coding-agent && npm run dev\" \"cd packages/mom && npm run dev\" \"cd packages/web-ui && npm run dev\" \"cd packages/tui && npm run dev\"",
|
||||
"dev:tsc": "concurrently --names \"ai,web-ui\" --prefix-colors \"cyan,green\" \"cd packages/ai && npm run dev:tsc\" \"cd packages/web-ui && npm run dev:tsc\"",
|
||||
"check": "biome check --write --error-on-warnings . && tsgo --noEmit && npm run check:browser-smoke && cd packages/web-ui && npm run check",
|
||||
"check:browser-smoke": "node scripts/check-browser-smoke.mjs",
|
||||
"profile:tui": "node scripts/profile-coding-agent-node.mjs --mode tui",
|
||||
"profile:rpc": "node scripts/profile-coding-agent-node.mjs --mode rpc",
|
||||
"test": "npm run test --workspaces --if-present",
|
||||
"version:patch": "npm version patch -ws --no-git-tag-version && node scripts/sync-versions.js && shx rm -rf node_modules packages/*/node_modules package-lock.json && npm install",
|
||||
"version:minor": "npm version minor -ws --no-git-tag-version && node scripts/sync-versions.js && shx rm -rf node_modules packages/*/node_modules package-lock.json && npm install",
|
||||
"version:major": "npm version major -ws --no-git-tag-version && node scripts/sync-versions.js && shx rm -rf node_modules packages/*/node_modules package-lock.json && npm install",
|
||||
"version:set": "npm version -ws",
|
||||
"prepublishOnly": "npm run clean && npm run build && npm run check",
|
||||
"publish": "npm run prepublishOnly && npm publish -ws --access public",
|
||||
"publish:dry": "npm run prepublishOnly && npm publish -ws --access public --dry-run",
|
||||
"release:patch": "node scripts/release.mjs patch",
|
||||
"release:minor": "node scripts/release.mjs minor",
|
||||
"release:major": "node scripts/release.mjs major",
|
||||
"prepare": "husky"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@biomejs/biome": "2.3.5",
|
||||
"@types/node": "^22.10.5",
|
||||
"@typescript/native-preview": "7.0.0-dev.20260120.1",
|
||||
"concurrently": "^9.2.1",
|
||||
"husky": "^9.1.7",
|
||||
"tsx": "^4.20.3",
|
||||
"typescript": "^5.9.2",
|
||||
"shx": "^0.4.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
},
|
||||
"version": "0.0.3",
|
||||
"dependencies": {
|
||||
"@mariozechner/jiti": "^2.6.5",
|
||||
"@jaeswift/jae-coding-agent": "^0.30.2",
|
||||
"get-east-asian-width": "^1.4.0"
|
||||
},
|
||||
"overrides": {
|
||||
"rimraf": "6.1.2",
|
||||
"fast-xml-parser": "5.3.8",
|
||||
"gaxios": {
|
||||
"rimraf": "6.1.2"
|
||||
}
|
||||
}
|
||||
}
|
||||
300
packages/agent/CHANGELOG.md
Normal file
300
packages/agent/CHANGELOG.md
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
# Changelog
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.62.0] - 2026-03-23
|
||||
|
||||
## [0.61.1] - 2026-03-20
|
||||
|
||||
## [0.61.0] - 2026-03-20
|
||||
|
||||
## [0.60.0] - 2026-03-18
|
||||
|
||||
## [0.59.0] - 2026-03-17
|
||||
|
||||
## [0.58.4] - 2026-03-16
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed steering messages to wait until the current assistant message's tool-call batch fully finishes instead of skipping pending tool calls.
|
||||
|
||||
## [0.58.3] - 2026-03-15
|
||||
|
||||
## [0.58.2] - 2026-03-15
|
||||
|
||||
## [0.58.1] - 2026-03-14
|
||||
|
||||
## [0.58.0] - 2026-03-14
|
||||
|
||||
### Added
|
||||
|
||||
- Added `beforeToolCall` and `afterToolCall` hooks to `AgentOptions` and `AgentLoopConfig` for preflight blocking and post-execution tool result mutation.
|
||||
|
||||
### Changed
|
||||
|
||||
- Added configurable tool execution mode to `Agent` and `agentLoop` via `toolExecution: "parallel" | "sequential"`, with `parallel` as the default. Parallel mode preflights tool calls sequentially, executes allowed tools concurrently, and emits final tool results in assistant source order.
|
||||
|
||||
## [0.57.1] - 2026-03-07
|
||||
|
||||
## [0.57.0] - 2026-03-07
|
||||
|
||||
## [0.56.3] - 2026-03-06
|
||||
|
||||
## [0.56.2] - 2026-03-05
|
||||
|
||||
## [0.56.1] - 2026-03-05
|
||||
|
||||
## [0.56.0] - 2026-03-04
|
||||
|
||||
## [0.55.4] - 2026-03-02
|
||||
|
||||
## [0.55.3] - 2026-02-27
|
||||
|
||||
## [0.55.2] - 2026-02-27
|
||||
|
||||
## [0.55.1] - 2026-02-26
|
||||
|
||||
## [0.55.0] - 2026-02-24
|
||||
|
||||
## [0.54.2] - 2026-02-23
|
||||
|
||||
## [0.54.1] - 2026-02-22
|
||||
|
||||
## [0.54.0] - 2026-02-19
|
||||
|
||||
## [0.53.1] - 2026-02-19
|
||||
|
||||
## [0.53.0] - 2026-02-17
|
||||
|
||||
## [0.52.12] - 2026-02-13
|
||||
|
||||
### Added
|
||||
|
||||
- Added `transport` to `AgentOptions` and `AgentLoopConfig` forwarding, allowing stream transport preference (`"sse"`, `"websocket"`, `"auto"`) to flow into provider calls.
|
||||
|
||||
## [0.52.11] - 2026-02-13
|
||||
|
||||
## [0.52.10] - 2026-02-12
|
||||
|
||||
## [0.52.9] - 2026-02-08
|
||||
|
||||
## [0.52.8] - 2026-02-07
|
||||
|
||||
## [0.52.7] - 2026-02-06
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `continue()` to resume queued steering/follow-up messages when context currently ends in an assistant message, and preserved one-at-a-time steering ordering during assistant-tail resumes ([#1312](https://github.com/jae/agent-jae-cli/pull/1312) by [@ferologics](https://github.com/ferologics))
|
||||
|
||||
## [0.52.6] - 2026-02-05
|
||||
|
||||
## [0.52.5] - 2026-02-05
|
||||
|
||||
## [0.52.4] - 2026-02-05
|
||||
|
||||
## [0.52.3] - 2026-02-05
|
||||
|
||||
## [0.52.2] - 2026-02-05
|
||||
|
||||
## [0.52.1] - 2026-02-05
|
||||
|
||||
## [0.52.0] - 2026-02-05
|
||||
|
||||
## [0.51.6] - 2026-02-04
|
||||
|
||||
## [0.51.5] - 2026-02-04
|
||||
|
||||
## [0.51.4] - 2026-02-03
|
||||
|
||||
## [0.51.3] - 2026-02-03
|
||||
|
||||
## [0.51.2] - 2026-02-03
|
||||
|
||||
## [0.51.1] - 2026-02-02
|
||||
|
||||
## [0.51.0] - 2026-02-01
|
||||
|
||||
## [0.50.9] - 2026-02-01
|
||||
|
||||
## [0.50.8] - 2026-02-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added `maxRetryDelayMs` option to `AgentOptions` to cap server-requested retry delays. Passed through to the underlying stream function. ([#1123](https://github.com/jae/agent-jae-cli/issues/1123))
|
||||
|
||||
## [0.50.7] - 2026-01-31
|
||||
|
||||
## [0.50.6] - 2026-01-30
|
||||
|
||||
## [0.50.5] - 2026-01-30
|
||||
|
||||
## [0.50.3] - 2026-01-29
|
||||
|
||||
## [0.50.2] - 2026-01-29
|
||||
|
||||
## [0.50.1] - 2026-01-26
|
||||
|
||||
## [0.50.0] - 2026-01-26
|
||||
|
||||
## [0.49.3] - 2026-01-22
|
||||
|
||||
## [0.49.2] - 2026-01-19
|
||||
|
||||
## [0.49.1] - 2026-01-18
|
||||
|
||||
## [0.49.0] - 2026-01-17
|
||||
|
||||
## [0.48.0] - 2026-01-16
|
||||
|
||||
## [0.47.0] - 2026-01-16
|
||||
|
||||
## [0.46.0] - 2026-01-15
|
||||
|
||||
## [0.45.7] - 2026-01-13
|
||||
|
||||
## [0.45.6] - 2026-01-13
|
||||
|
||||
## [0.45.5] - 2026-01-13
|
||||
|
||||
## [0.45.4] - 2026-01-13
|
||||
|
||||
## [0.45.3] - 2026-01-13
|
||||
|
||||
## [0.45.2] - 2026-01-13
|
||||
|
||||
## [0.45.1] - 2026-01-13
|
||||
|
||||
## [0.45.0] - 2026-01-13
|
||||
|
||||
## [0.44.0] - 2026-01-12
|
||||
|
||||
## [0.43.0] - 2026-01-11
|
||||
|
||||
## [0.42.5] - 2026-01-11
|
||||
|
||||
## [0.42.4] - 2026-01-10
|
||||
|
||||
## [0.42.3] - 2026-01-10
|
||||
|
||||
## [0.42.2] - 2026-01-10
|
||||
|
||||
## [0.42.1] - 2026-01-09
|
||||
|
||||
## [0.42.0] - 2026-01-09
|
||||
|
||||
## [0.41.0] - 2026-01-09
|
||||
|
||||
## [0.40.1] - 2026-01-09
|
||||
|
||||
## [0.40.0] - 2026-01-08
|
||||
|
||||
## [0.39.1] - 2026-01-08
|
||||
|
||||
## [0.39.0] - 2026-01-08
|
||||
|
||||
## [0.38.0] - 2026-01-08
|
||||
|
||||
### Added
|
||||
|
||||
- `thinkingBudgets` option on `Agent` and `AgentOptions` to customize token budgets per thinking level ([#529](https://github.com/jae/agent-jae-cli/pull/529) by [@melihmucuk](https://github.com/melihmucuk))
|
||||
|
||||
## [0.37.8] - 2026-01-07
|
||||
|
||||
## [0.37.7] - 2026-01-07
|
||||
|
||||
## [0.37.6] - 2026-01-06
|
||||
|
||||
## [0.37.5] - 2026-01-06
|
||||
|
||||
## [0.37.4] - 2026-01-06
|
||||
|
||||
## [0.37.3] - 2026-01-06
|
||||
|
||||
### Added
|
||||
|
||||
- `sessionId` option on `Agent` to forward session identifiers to LLM providers for session-based caching.
|
||||
|
||||
## [0.37.2] - 2026-01-05
|
||||
|
||||
## [0.37.1] - 2026-01-05
|
||||
|
||||
## [0.37.0] - 2026-01-05
|
||||
|
||||
### Fixed
|
||||
|
||||
- `minimal` thinking level now maps to `minimal` reasoning effort instead of being treated as `low`.
|
||||
|
||||
## [0.36.0] - 2026-01-05
|
||||
|
||||
## [0.35.0] - 2026-01-05
|
||||
|
||||
## [0.34.2] - 2026-01-04
|
||||
|
||||
## [0.34.1] - 2026-01-04
|
||||
|
||||
## [0.34.0] - 2026-01-04
|
||||
|
||||
## [0.33.0] - 2026-01-04
|
||||
|
||||
## [0.32.3] - 2026-01-03
|
||||
|
||||
## [0.32.2] - 2026-01-03
|
||||
|
||||
## [0.32.1] - 2026-01-03
|
||||
|
||||
## [0.32.0] - 2026-01-03
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Queue API replaced with steer/followUp**: The `queueMessage()` method has been split into two methods with different delivery semantics ([#403](https://github.com/jae/agent-jae-cli/issues/403)):
|
||||
- `steer(msg)`: Interrupts the agent mid-run. Delivered after current tool execution, skips remaining tools.
|
||||
- `followUp(msg)`: Waits until the agent finishes. Delivered only when there are no more tool calls or steering messages.
|
||||
- **Queue mode renamed**: `queueMode` option renamed to `steeringMode`. Added new `followUpMode` option. Both control whether messages are delivered one-at-a-time or all at once.
|
||||
- **AgentLoopConfig callbacks renamed**: `getQueuedMessages` split into `getSteeringMessages` and `getFollowUpMessages`.
|
||||
- **Agent methods renamed**:
|
||||
- `queueMessage()` → `steer()` and `followUp()`
|
||||
- `clearMessageQueue()` → `clearSteeringQueue()`, `clearFollowUpQueue()`, `clearAllQueues()`
|
||||
- `setQueueMode()`/`getQueueMode()` → `setSteeringMode()`/`getSteeringMode()` and `setFollowUpMode()`/`getFollowUpMode()`
|
||||
|
||||
### Fixed
|
||||
|
||||
- `prompt()` and `continue()` now throw if called while the agent is already streaming, preventing race conditions and corrupted state. Use `steer()` or `followUp()` to queue messages during streaming, or `await` the previous call.
|
||||
|
||||
## [0.31.1] - 2026-01-02
|
||||
|
||||
## [0.31.0] - 2026-01-02
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Transport abstraction removed**: `ProviderTransport`, `AppTransport`, and `AgentTransport` interface have been removed. Use the `streamFn` option directly for custom streaming implementations.
|
||||
|
||||
- **Agent options renamed**:
|
||||
- `transport` → removed (use `streamFn` instead)
|
||||
- `messageTransformer` → `convertToLlm`
|
||||
- `preprocessor` → `transformContext`
|
||||
|
||||
- **`AppMessage` renamed to `AgentMessage`**: All references to `AppMessage` have been renamed to `AgentMessage` for consistency.
|
||||
|
||||
- **`CustomMessages` renamed to `CustomAgentMessages`**: The declaration merging interface has been renamed.
|
||||
|
||||
- **`UserMessageWithAttachments` and `Attachment` types removed**: Attachment handling is now the responsibility of the `convertToLlm` function.
|
||||
|
||||
- **Agent loop moved from `@jaeswift/jae-ai`**: The `agentLoop`, `agentLoopContinue`, and related types have moved to this package. Import from `@jaeswift/jae-agent-core` instead.
|
||||
|
||||
### Added
|
||||
|
||||
- `streamFn` option on `Agent` for custom stream implementations. Default uses `streamSimple` from pi-ai.
|
||||
|
||||
- `streamProxy()` utility function for browser apps that need to proxy LLM calls through a backend server. Replaces the removed `AppTransport`.
|
||||
|
||||
- `getApiKey` option for dynamic API key resolution (useful for expiring OAuth tokens like GitHub Copilot).
|
||||
|
||||
- `agentLoop()` and `agentLoopContinue()` low-level functions for running the agent loop without the `Agent` class wrapper.
|
||||
|
||||
- New exported types: `AgentLoopConfig`, `AgentContext`, `AgentTool`, `AgentToolResult`, `AgentToolUpdateCallback`, `StreamFn`.
|
||||
|
||||
### Changed
|
||||
|
||||
- `Agent` constructor now has all options optional (empty options use defaults).
|
||||
|
||||
- `queueMessage()` is now synchronous (no longer returns a Promise).
|
||||
446
packages/agent/README.md
Normal file
446
packages/agent/README.md
Normal file
|
|
@ -0,0 +1,446 @@
|
|||
# @jaeswift/jae-agent-core
|
||||
|
||||
Stateful agent with tool execution and event streaming. Built on `@jaeswift/jae-ai`.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install @jaeswift/jae-agent-core
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```typescript
|
||||
import { Agent } from "@jaeswift/jae-agent-core";
|
||||
import { getModel } from "@jaeswift/jae-ai";
|
||||
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant.",
|
||||
model: getModel("anthropic", "claude-sonnet-4-20250514"),
|
||||
},
|
||||
});
|
||||
|
||||
agent.subscribe((event) => {
|
||||
if (event.type === "message_update" && event.assistantMessageEvent.type === "text_delta") {
|
||||
// Stream just the new text chunk
|
||||
process.stdout.write(event.assistantMessageEvent.delta);
|
||||
}
|
||||
});
|
||||
|
||||
await agent.prompt("Hello!");
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### AgentMessage vs LLM Message
|
||||
|
||||
The agent works with `AgentMessage`, a flexible type that can include:
|
||||
- Standard LLM messages (`user`, `assistant`, `toolResult`)
|
||||
- Custom app-specific message types via declaration merging
|
||||
|
||||
LLMs only understand `user`, `assistant`, and `toolResult`. The `convertToLlm` function bridges this gap by filtering and transforming messages before each LLM call.
|
||||
|
||||
### Message Flow
|
||||
|
||||
```
|
||||
AgentMessage[] → transformContext() → AgentMessage[] → convertToLlm() → Message[] → LLM
|
||||
(optional) (required)
|
||||
```
|
||||
|
||||
1. **transformContext**: Prune old messages, inject external context
|
||||
2. **convertToLlm**: Filter out UI-only messages, convert custom types to LLM format
|
||||
|
||||
## Event Flow
|
||||
|
||||
The agent emits events for UI updates. Understanding the event sequence helps build responsive interfaces.
|
||||
|
||||
### prompt() Event Sequence
|
||||
|
||||
When you call `prompt("Hello")`:
|
||||
|
||||
```
|
||||
prompt("Hello")
|
||||
├─ agent_start
|
||||
├─ turn_start
|
||||
├─ message_start { message: userMessage } // Your prompt
|
||||
├─ message_end { message: userMessage }
|
||||
├─ message_start { message: assistantMessage } // LLM starts responding
|
||||
├─ message_update { message: partial... } // Streaming chunks
|
||||
├─ message_update { message: partial... }
|
||||
├─ message_end { message: assistantMessage } // Complete response
|
||||
├─ turn_end { message, toolResults: [] }
|
||||
└─ agent_end { messages: [...] }
|
||||
```
|
||||
|
||||
### With Tool Calls
|
||||
|
||||
If the assistant calls tools, the loop continues:
|
||||
|
||||
```
|
||||
prompt("Read config.json")
|
||||
├─ agent_start
|
||||
├─ turn_start
|
||||
├─ message_start/end { userMessage }
|
||||
├─ message_start { assistantMessage with toolCall }
|
||||
├─ message_update...
|
||||
├─ message_end { assistantMessage }
|
||||
├─ tool_execution_start { toolCallId, toolName, args }
|
||||
├─ tool_execution_update { partialResult } // If tool streams
|
||||
├─ tool_execution_end { toolCallId, result }
|
||||
├─ message_start/end { toolResultMessage }
|
||||
├─ turn_end { message, toolResults: [toolResult] }
|
||||
│
|
||||
├─ turn_start // Next turn
|
||||
├─ message_start { assistantMessage } // LLM responds to tool result
|
||||
├─ message_update...
|
||||
├─ message_end
|
||||
├─ turn_end
|
||||
└─ agent_end
|
||||
```
|
||||
|
||||
Tool execution mode is configurable:
|
||||
|
||||
- `parallel` (default): preflight tool calls sequentially, execute allowed tools concurrently, emit final `tool_execution_end` and `toolResult` messages in assistant source order
|
||||
- `sequential`: execute tool calls one by one, matching the historical behavior
|
||||
|
||||
The `beforeToolCall` hook runs after `tool_execution_start` and validated argument parsing. It can block execution. The `afterToolCall` hook runs after tool execution finishes and before `tool_execution_end` and final tool result message events are emitted.
|
||||
|
||||
When you use the `Agent` class, assistant `message_end` processing is treated as a barrier before tool preflight begins. That means `beforeToolCall` sees agent state that already includes the assistant message that requested the tool call.
|
||||
|
||||
### continue() Event Sequence
|
||||
|
||||
`continue()` resumes from existing context without adding a new message. Use it for retries after errors.
|
||||
|
||||
```typescript
|
||||
// After an error, retry from current state
|
||||
await agent.continue();
|
||||
```
|
||||
|
||||
The last message in context must be `user` or `toolResult` (not `assistant`).
|
||||
|
||||
### Event Types
|
||||
|
||||
| Event | Description |
|
||||
|-------|-------------|
|
||||
| `agent_start` | Agent begins processing |
|
||||
| `agent_end` | Agent completes with all new messages |
|
||||
| `turn_start` | New turn begins (one LLM call + tool executions) |
|
||||
| `turn_end` | Turn completes with assistant message and tool results |
|
||||
| `message_start` | Any message begins (user, assistant, toolResult) |
|
||||
| `message_update` | **Assistant only.** Includes `assistantMessageEvent` with delta |
|
||||
| `message_end` | Message completes |
|
||||
| `tool_execution_start` | Tool begins |
|
||||
| `tool_execution_update` | Tool streams progress |
|
||||
| `tool_execution_end` | Tool completes |
|
||||
|
||||
## Agent Options
|
||||
|
||||
```typescript
|
||||
const agent = new Agent({
|
||||
// Initial state
|
||||
initialState: {
|
||||
systemPrompt: string,
|
||||
model: Model<any>,
|
||||
thinkingLevel: "off" | "minimal" | "low" | "medium" | "high" | "xhigh",
|
||||
tools: AgentTool<any>[],
|
||||
messages: AgentMessage[],
|
||||
},
|
||||
|
||||
// Convert AgentMessage[] to LLM Message[] (required for custom message types)
|
||||
convertToLlm: (messages) => messages.filter(...),
|
||||
|
||||
// Transform context before convertToLlm (for pruning, compaction)
|
||||
transformContext: async (messages, signal) => pruneOldMessages(messages),
|
||||
|
||||
// Steering mode: "one-at-a-time" (default) or "all"
|
||||
steeringMode: "one-at-a-time",
|
||||
|
||||
// Follow-up mode: "one-at-a-time" (default) or "all"
|
||||
followUpMode: "one-at-a-time",
|
||||
|
||||
// Custom stream function (for proxy backends)
|
||||
streamFn: streamProxy,
|
||||
|
||||
// Session ID for provider caching
|
||||
sessionId: "session-123",
|
||||
|
||||
// Dynamic API key resolution (for expiring OAuth tokens)
|
||||
getApiKey: async (provider) => refreshToken(),
|
||||
|
||||
// Tool execution mode: "parallel" (default) or "sequential"
|
||||
toolExecution: "parallel",
|
||||
|
||||
// Preflight each tool call after args are validated. Can block execution.
|
||||
beforeToolCall: async ({ toolCall, args, context }) => {
|
||||
if (toolCall.name === "bash") {
|
||||
return { block: true, reason: "bash is disabled" };
|
||||
}
|
||||
},
|
||||
|
||||
// Postprocess each tool result before final tool events are emitted.
|
||||
afterToolCall: async ({ toolCall, result, isError, context }) => {
|
||||
if (!isError) {
|
||||
return { details: { ...result.details, audited: true } };
|
||||
}
|
||||
},
|
||||
|
||||
// Custom thinking budgets for token-based providers
|
||||
thinkingBudgets: {
|
||||
minimal: 128,
|
||||
low: 512,
|
||||
medium: 1024,
|
||||
high: 2048,
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
## Agent State
|
||||
|
||||
```typescript
|
||||
interface AgentState {
|
||||
systemPrompt: string;
|
||||
model: Model<any>;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
tools: AgentTool<any>[];
|
||||
messages: AgentMessage[];
|
||||
isStreaming: boolean;
|
||||
streamMessage: AgentMessage | null; // Current partial during streaming
|
||||
pendingToolCalls: Set<string>;
|
||||
error?: string;
|
||||
}
|
||||
```
|
||||
|
||||
Access via `agent.state`. During streaming, `streamMessage` contains the partial assistant message.
|
||||
|
||||
## Methods
|
||||
|
||||
### Prompting
|
||||
|
||||
```typescript
|
||||
// Text prompt
|
||||
await agent.prompt("Hello");
|
||||
|
||||
// With images
|
||||
await agent.prompt("What's in this image?", [
|
||||
{ type: "image", data: base64Data, mimeType: "image/jpeg" }
|
||||
]);
|
||||
|
||||
// AgentMessage directly
|
||||
await agent.prompt({ role: "user", content: "Hello", timestamp: Date.now() });
|
||||
|
||||
// Continue from current context (last message must be user or toolResult)
|
||||
await agent.continue();
|
||||
```
|
||||
|
||||
### State Management
|
||||
|
||||
```typescript
|
||||
agent.setSystemPrompt("New prompt");
|
||||
agent.setModel(getModel("openai", "gpt-4o"));
|
||||
agent.setThinkingLevel("medium");
|
||||
agent.setTools([myTool]);
|
||||
agent.setToolExecution("sequential");
|
||||
agent.setBeforeToolCall(async ({ toolCall }) => undefined);
|
||||
agent.setAfterToolCall(async ({ toolCall, result }) => undefined);
|
||||
agent.replaceMessages(newMessages);
|
||||
agent.appendMessage(message);
|
||||
agent.clearMessages();
|
||||
agent.reset(); // Clear everything
|
||||
```
|
||||
|
||||
### Session and Thinking Budgets
|
||||
|
||||
```typescript
|
||||
agent.sessionId = "session-123";
|
||||
|
||||
agent.thinkingBudgets = {
|
||||
minimal: 128,
|
||||
low: 512,
|
||||
medium: 1024,
|
||||
high: 2048,
|
||||
};
|
||||
```
|
||||
|
||||
### Control
|
||||
|
||||
```typescript
|
||||
agent.abort(); // Cancel current operation
|
||||
await agent.waitForIdle(); // Wait for completion
|
||||
```
|
||||
|
||||
### Events
|
||||
|
||||
```typescript
|
||||
const unsubscribe = agent.subscribe((event) => {
|
||||
console.log(event.type);
|
||||
});
|
||||
unsubscribe();
|
||||
```
|
||||
|
||||
## Steering and Follow-up
|
||||
|
||||
Steering messages let you interrupt the agent while tools are running. Follow-up messages let you queue work after the agent would otherwise stop.
|
||||
|
||||
```typescript
|
||||
agent.setSteeringMode("one-at-a-time");
|
||||
agent.setFollowUpMode("one-at-a-time");
|
||||
|
||||
// While agent is running tools
|
||||
agent.steer({
|
||||
role: "user",
|
||||
content: "Stop! Do this instead.",
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
// After the agent finishes its current work
|
||||
agent.followUp({
|
||||
role: "user",
|
||||
content: "Also summarize the result.",
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
const steeringMode = agent.getSteeringMode();
|
||||
const followUpMode = agent.getFollowUpMode();
|
||||
|
||||
agent.clearSteeringQueue();
|
||||
agent.clearFollowUpQueue();
|
||||
agent.clearAllQueues();
|
||||
```
|
||||
|
||||
Use clearSteeringQueue, clearFollowUpQueue, or clearAllQueues to drop queued messages.
|
||||
|
||||
When steering messages are detected after a turn completes:
|
||||
1. All tool calls from the current assistant message have already finished
|
||||
2. Steering messages are injected
|
||||
3. The LLM responds on the next turn
|
||||
|
||||
Follow-up messages are checked only when there are no more tool calls and no steering messages. If any are queued, they are injected and another turn runs.
|
||||
|
||||
## Custom Message Types
|
||||
|
||||
Extend `AgentMessage` via declaration merging:
|
||||
|
||||
```typescript
|
||||
declare module "@jaeswift/jae-agent-core" {
|
||||
interface CustomAgentMessages {
|
||||
notification: { role: "notification"; text: string; timestamp: number };
|
||||
}
|
||||
}
|
||||
|
||||
// Now valid
|
||||
const msg: AgentMessage = { role: "notification", text: "Info", timestamp: Date.now() };
|
||||
```
|
||||
|
||||
Handle custom types in `convertToLlm`:
|
||||
|
||||
```typescript
|
||||
const agent = new Agent({
|
||||
convertToLlm: (messages) => messages.flatMap(m => {
|
||||
if (m.role === "notification") return []; // Filter out
|
||||
return [m];
|
||||
}),
|
||||
});
|
||||
```
|
||||
|
||||
## Tools
|
||||
|
||||
Define tools using `AgentTool`:
|
||||
|
||||
```typescript
|
||||
import { Type } from "@sinclair/typebox";
|
||||
|
||||
const readFileTool: AgentTool = {
|
||||
name: "read_file",
|
||||
label: "Read File", // For UI display
|
||||
description: "Read a file's contents",
|
||||
parameters: Type.Object({
|
||||
path: Type.String({ description: "File path" }),
|
||||
}),
|
||||
execute: async (toolCallId, params, signal, onUpdate) => {
|
||||
const content = await fs.readFile(params.path, "utf-8");
|
||||
|
||||
// Optional: stream progress
|
||||
onUpdate?.({ content: [{ type: "text", text: "Reading..." }], details: {} });
|
||||
|
||||
return {
|
||||
content: [{ type: "text", text: content }],
|
||||
details: { path: params.path, size: content.length },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
agent.setTools([readFileTool]);
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
**Throw an error** when a tool fails. Do not return error messages as content.
|
||||
|
||||
```typescript
|
||||
execute: async (toolCallId, params, signal, onUpdate) => {
|
||||
if (!fs.existsSync(params.path)) {
|
||||
throw new Error(`File not found: ${params.path}`);
|
||||
}
|
||||
// Return content only on success
|
||||
return { content: [{ type: "text", text: "..." }] };
|
||||
}
|
||||
```
|
||||
|
||||
Thrown errors are caught by the agent and reported to the LLM as tool errors with `isError: true`.
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
For browser apps that proxy through a backend:
|
||||
|
||||
```typescript
|
||||
import { Agent, streamProxy } from "@jaeswift/jae-agent-core";
|
||||
|
||||
const agent = new Agent({
|
||||
streamFn: (model, context, options) =>
|
||||
streamProxy(model, context, {
|
||||
...options,
|
||||
authToken: "...",
|
||||
proxyUrl: "https://your-server.com",
|
||||
}),
|
||||
});
|
||||
```
|
||||
|
||||
## Low-Level API
|
||||
|
||||
For direct control without the Agent class:
|
||||
|
||||
```typescript
|
||||
import { agentLoop, agentLoopContinue } from "@jaeswift/jae-agent-core";
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: getModel("openai", "gpt-4o"),
|
||||
convertToLlm: (msgs) => msgs.filter(m => ["user", "assistant", "toolResult"].includes(m.role)),
|
||||
toolExecution: "parallel",
|
||||
beforeToolCall: async ({ toolCall, args, context }) => undefined,
|
||||
afterToolCall: async ({ toolCall, result, isError, context }) => undefined,
|
||||
};
|
||||
|
||||
const userMessage = { role: "user", content: "Hello", timestamp: Date.now() };
|
||||
|
||||
for await (const event of agentLoop([userMessage], context, config)) {
|
||||
console.log(event.type);
|
||||
}
|
||||
|
||||
// Continue from existing context
|
||||
for await (const event of agentLoopContinue(context, config)) {
|
||||
console.log(event.type);
|
||||
}
|
||||
```
|
||||
|
||||
These low-level streams are observational. They preserve event order, but they do not wait for your async event handling to settle before later producer phases continue. If you need message processing to act as a barrier before tool preflight, use the `Agent` class instead of raw `agentLoop()` or `agentLoopContinue()`.
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
44
packages/agent/package.json
Normal file
44
packages/agent/package.json
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
{
|
||||
"name": "@jaeswift/jae-agent-core",
|
||||
"version": "0.62.0",
|
||||
"description": "General-purpose agent with transport abstraction, state management, and attachment support",
|
||||
"type": "module",
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"files": [
|
||||
"dist",
|
||||
"README.md"
|
||||
],
|
||||
"scripts": {
|
||||
"clean": "shx rm -rf dist",
|
||||
"build": "tsgo -p tsconfig.build.json",
|
||||
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"test": "vitest --run",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@jaeswift/jae-ai": "^0.62.0"
|
||||
},
|
||||
"keywords": [
|
||||
"ai",
|
||||
"agent",
|
||||
"llm",
|
||||
"transport",
|
||||
"state-management"
|
||||
],
|
||||
"author": "Mario Zechner",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/jae/agent-jae-cli.git",
|
||||
"directory": "packages/agent"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^24.3.0",
|
||||
"typescript": "^5.7.3",
|
||||
"vitest": "^3.2.4"
|
||||
}
|
||||
}
|
||||
616
packages/agent/src/agent-loop.ts
Normal file
616
packages/agent/src/agent-loop.ts
Normal file
|
|
@ -0,0 +1,616 @@
|
|||
/**
|
||||
* Agent loop that works with AgentMessage throughout.
|
||||
* Transforms to Message[] only at the LLM call boundary.
|
||||
*/
|
||||
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type Context,
|
||||
EventStream,
|
||||
streamSimple,
|
||||
type ToolResultMessage,
|
||||
validateToolArguments,
|
||||
} from "@jaeswift/jae-ai";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentTool,
|
||||
AgentToolCall,
|
||||
AgentToolResult,
|
||||
StreamFn,
|
||||
} from "./types.js";
|
||||
|
||||
export type AgentEventSink = (event: AgentEvent) => Promise<void> | void;
|
||||
|
||||
/**
|
||||
* Start an agent loop with a new prompt message.
|
||||
* The prompt is added to the context and events are emitted for it.
|
||||
*/
|
||||
export function agentLoop(
|
||||
prompts: AgentMessage[],
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
const stream = createAgentStream();
|
||||
|
||||
void runAgentLoop(
|
||||
prompts,
|
||||
context,
|
||||
config,
|
||||
async (event) => {
|
||||
stream.push(event);
|
||||
},
|
||||
signal,
|
||||
streamFn,
|
||||
).then((messages) => {
|
||||
stream.end(messages);
|
||||
});
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue an agent loop from the current context without adding a new message.
|
||||
* Used for retries - context already has user message or tool results.
|
||||
*
|
||||
* **Important:** The last message in context must convert to a `user` or `toolResult` message
|
||||
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
|
||||
* This cannot be validated here since `convertToLlm` is only called once per turn.
|
||||
*/
|
||||
export function agentLoopContinue(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
if (context.messages.length === 0) {
|
||||
throw new Error("Cannot continue: no messages in context");
|
||||
}
|
||||
|
||||
if (context.messages[context.messages.length - 1].role === "assistant") {
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
const stream = createAgentStream();
|
||||
|
||||
void runAgentLoopContinue(
|
||||
context,
|
||||
config,
|
||||
async (event) => {
|
||||
stream.push(event);
|
||||
},
|
||||
signal,
|
||||
streamFn,
|
||||
).then((messages) => {
|
||||
stream.end(messages);
|
||||
});
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
export async function runAgentLoop(
|
||||
prompts: AgentMessage[],
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
emit: AgentEventSink,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<AgentMessage[]> {
|
||||
const newMessages: AgentMessage[] = [...prompts];
|
||||
const currentContext: AgentContext = {
|
||||
...context,
|
||||
messages: [...context.messages, ...prompts],
|
||||
};
|
||||
|
||||
await emit({ type: "agent_start" });
|
||||
await emit({ type: "turn_start" });
|
||||
for (const prompt of prompts) {
|
||||
await emit({ type: "message_start", message: prompt });
|
||||
await emit({ type: "message_end", message: prompt });
|
||||
}
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, emit, streamFn);
|
||||
return newMessages;
|
||||
}
|
||||
|
||||
export async function runAgentLoopContinue(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
emit: AgentEventSink,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<AgentMessage[]> {
|
||||
if (context.messages.length === 0) {
|
||||
throw new Error("Cannot continue: no messages in context");
|
||||
}
|
||||
|
||||
if (context.messages[context.messages.length - 1].role === "assistant") {
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
const newMessages: AgentMessage[] = [];
|
||||
const currentContext: AgentContext = { ...context };
|
||||
|
||||
await emit({ type: "agent_start" });
|
||||
await emit({ type: "turn_start" });
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, emit, streamFn);
|
||||
return newMessages;
|
||||
}
|
||||
|
||||
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
|
||||
return new EventStream<AgentEvent, AgentMessage[]>(
|
||||
(event: AgentEvent) => event.type === "agent_end",
|
||||
(event: AgentEvent) => (event.type === "agent_end" ? event.messages : []),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Main loop logic shared by agentLoop and agentLoopContinue.
|
||||
*/
|
||||
async function runLoop(
|
||||
currentContext: AgentContext,
|
||||
newMessages: AgentMessage[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
emit: AgentEventSink,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<void> {
|
||||
let firstTurn = true;
|
||||
// Check for steering messages at start (user may have typed while waiting)
|
||||
let pendingMessages: AgentMessage[] = (await config.getSteeringMessages?.()) || [];
|
||||
|
||||
// Outer loop: continues when queued follow-up messages arrive after agent would stop
|
||||
while (true) {
|
||||
let hasMoreToolCalls = true;
|
||||
|
||||
// Inner loop: process tool calls and steering messages
|
||||
while (hasMoreToolCalls || pendingMessages.length > 0) {
|
||||
if (!firstTurn) {
|
||||
await emit({ type: "turn_start" });
|
||||
} else {
|
||||
firstTurn = false;
|
||||
}
|
||||
|
||||
// Process pending messages (inject before next assistant response)
|
||||
if (pendingMessages.length > 0) {
|
||||
for (const message of pendingMessages) {
|
||||
await emit({ type: "message_start", message });
|
||||
await emit({ type: "message_end", message });
|
||||
currentContext.messages.push(message);
|
||||
newMessages.push(message);
|
||||
}
|
||||
pendingMessages = [];
|
||||
}
|
||||
|
||||
// Stream assistant response
|
||||
const message = await streamAssistantResponse(currentContext, config, signal, emit, streamFn);
|
||||
newMessages.push(message);
|
||||
|
||||
if (message.stopReason === "error" || message.stopReason === "aborted") {
|
||||
await emit({ type: "turn_end", message, toolResults: [] });
|
||||
await emit({ type: "agent_end", messages: newMessages });
|
||||
return;
|
||||
}
|
||||
|
||||
// Check for tool calls
|
||||
const toolCalls = message.content.filter((c) => c.type === "toolCall");
|
||||
hasMoreToolCalls = toolCalls.length > 0;
|
||||
|
||||
const toolResults: ToolResultMessage[] = [];
|
||||
if (hasMoreToolCalls) {
|
||||
toolResults.push(...(await executeToolCalls(currentContext, message, config, signal, emit)));
|
||||
|
||||
for (const result of toolResults) {
|
||||
currentContext.messages.push(result);
|
||||
newMessages.push(result);
|
||||
}
|
||||
}
|
||||
|
||||
await emit({ type: "turn_end", message, toolResults });
|
||||
|
||||
pendingMessages = (await config.getSteeringMessages?.()) || [];
|
||||
}
|
||||
|
||||
// Agent would stop here. Check for follow-up messages.
|
||||
const followUpMessages = (await config.getFollowUpMessages?.()) || [];
|
||||
if (followUpMessages.length > 0) {
|
||||
// Set as pending so inner loop processes them
|
||||
pendingMessages = followUpMessages;
|
||||
continue;
|
||||
}
|
||||
|
||||
// No more messages, exit
|
||||
break;
|
||||
}
|
||||
|
||||
await emit({ type: "agent_end", messages: newMessages });
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream an assistant response from the LLM.
|
||||
* This is where AgentMessage[] gets transformed to Message[] for the LLM.
|
||||
*/
|
||||
async function streamAssistantResponse(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
emit: AgentEventSink,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<AssistantMessage> {
|
||||
// Apply context transform if configured (AgentMessage[] → AgentMessage[])
|
||||
let messages = context.messages;
|
||||
if (config.transformContext) {
|
||||
messages = await config.transformContext(messages, signal);
|
||||
}
|
||||
|
||||
// Convert to LLM-compatible messages (AgentMessage[] → Message[])
|
||||
const llmMessages = await config.convertToLlm(messages);
|
||||
|
||||
// Build LLM context
|
||||
const llmContext: Context = {
|
||||
systemPrompt: context.systemPrompt,
|
||||
messages: llmMessages,
|
||||
tools: context.tools,
|
||||
};
|
||||
|
||||
const streamFunction = streamFn || streamSimple;
|
||||
|
||||
// Resolve API key (important for expiring tokens)
|
||||
const resolvedApiKey =
|
||||
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
|
||||
|
||||
const response = await streamFunction(config.model, llmContext, {
|
||||
...config,
|
||||
apiKey: resolvedApiKey,
|
||||
signal,
|
||||
});
|
||||
|
||||
let partialMessage: AssistantMessage | null = null;
|
||||
let addedPartial = false;
|
||||
|
||||
for await (const event of response) {
|
||||
switch (event.type) {
|
||||
case "start":
|
||||
partialMessage = event.partial;
|
||||
context.messages.push(partialMessage);
|
||||
addedPartial = true;
|
||||
await emit({ type: "message_start", message: { ...partialMessage } });
|
||||
break;
|
||||
|
||||
case "text_start":
|
||||
case "text_delta":
|
||||
case "text_end":
|
||||
case "thinking_start":
|
||||
case "thinking_delta":
|
||||
case "thinking_end":
|
||||
case "toolcall_start":
|
||||
case "toolcall_delta":
|
||||
case "toolcall_end":
|
||||
if (partialMessage) {
|
||||
partialMessage = event.partial;
|
||||
context.messages[context.messages.length - 1] = partialMessage;
|
||||
await emit({
|
||||
type: "message_update",
|
||||
assistantMessageEvent: event,
|
||||
message: { ...partialMessage },
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case "done":
|
||||
case "error": {
|
||||
const finalMessage = await response.result();
|
||||
if (addedPartial) {
|
||||
context.messages[context.messages.length - 1] = finalMessage;
|
||||
} else {
|
||||
context.messages.push(finalMessage);
|
||||
}
|
||||
if (!addedPartial) {
|
||||
await emit({ type: "message_start", message: { ...finalMessage } });
|
||||
}
|
||||
await emit({ type: "message_end", message: finalMessage });
|
||||
return finalMessage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const finalMessage = await response.result();
|
||||
if (addedPartial) {
|
||||
context.messages[context.messages.length - 1] = finalMessage;
|
||||
} else {
|
||||
context.messages.push(finalMessage);
|
||||
await emit({ type: "message_start", message: { ...finalMessage } });
|
||||
}
|
||||
await emit({ type: "message_end", message: finalMessage });
|
||||
return finalMessage;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute tool calls from an assistant message.
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
emit: AgentEventSink,
|
||||
): Promise<ToolResultMessage[]> {
|
||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
||||
if (config.toolExecution === "sequential") {
|
||||
return executeToolCallsSequential(currentContext, assistantMessage, toolCalls, config, signal, emit);
|
||||
}
|
||||
return executeToolCallsParallel(currentContext, assistantMessage, toolCalls, config, signal, emit);
|
||||
}
|
||||
|
||||
async function executeToolCallsSequential(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCalls: AgentToolCall[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
emit: AgentEventSink,
|
||||
): Promise<ToolResultMessage[]> {
|
||||
const results: ToolResultMessage[] = [];
|
||||
|
||||
for (const toolCall of toolCalls) {
|
||||
await emit({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal);
|
||||
if (preparation.kind === "immediate") {
|
||||
results.push(await emitToolCallOutcome(toolCall, preparation.result, preparation.isError, emit));
|
||||
} else {
|
||||
const executed = await executePreparedToolCall(preparation, signal, emit);
|
||||
results.push(
|
||||
await finalizeExecutedToolCall(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
preparation,
|
||||
executed,
|
||||
config,
|
||||
signal,
|
||||
emit,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
async function executeToolCallsParallel(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCalls: AgentToolCall[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
emit: AgentEventSink,
|
||||
): Promise<ToolResultMessage[]> {
|
||||
const results: ToolResultMessage[] = [];
|
||||
const runnableCalls: PreparedToolCall[] = [];
|
||||
|
||||
for (const toolCall of toolCalls) {
|
||||
await emit({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal);
|
||||
if (preparation.kind === "immediate") {
|
||||
results.push(await emitToolCallOutcome(toolCall, preparation.result, preparation.isError, emit));
|
||||
} else {
|
||||
runnableCalls.push(preparation);
|
||||
}
|
||||
}
|
||||
|
||||
const runningCalls = runnableCalls.map((prepared) => ({
|
||||
prepared,
|
||||
execution: executePreparedToolCall(prepared, signal, emit),
|
||||
}));
|
||||
|
||||
for (const running of runningCalls) {
|
||||
const executed = await running.execution;
|
||||
results.push(
|
||||
await finalizeExecutedToolCall(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
running.prepared,
|
||||
executed,
|
||||
config,
|
||||
signal,
|
||||
emit,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
type PreparedToolCall = {
|
||||
kind: "prepared";
|
||||
toolCall: AgentToolCall;
|
||||
tool: AgentTool<any>;
|
||||
args: unknown;
|
||||
};
|
||||
|
||||
type ImmediateToolCallOutcome = {
|
||||
kind: "immediate";
|
||||
result: AgentToolResult<any>;
|
||||
isError: boolean;
|
||||
};
|
||||
|
||||
type ExecutedToolCallOutcome = {
|
||||
result: AgentToolResult<any>;
|
||||
isError: boolean;
|
||||
};
|
||||
|
||||
async function prepareToolCall(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCall: AgentToolCall,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
): Promise<PreparedToolCall | ImmediateToolCallOutcome> {
|
||||
const tool = currentContext.tools?.find((t) => t.name === toolCall.name);
|
||||
if (!tool) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(`Tool ${toolCall.name} not found`),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
if (config.beforeToolCall) {
|
||||
const beforeResult = await config.beforeToolCall(
|
||||
{
|
||||
assistantMessage,
|
||||
toolCall,
|
||||
args: validatedArgs,
|
||||
context: currentContext,
|
||||
},
|
||||
signal,
|
||||
);
|
||||
if (beforeResult?.block) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(beforeResult.reason || "Tool execution was blocked"),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
return {
|
||||
kind: "prepared",
|
||||
toolCall,
|
||||
tool,
|
||||
args: validatedArgs,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(error instanceof Error ? error.message : String(error)),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async function executePreparedToolCall(
|
||||
prepared: PreparedToolCall,
|
||||
signal: AbortSignal | undefined,
|
||||
emit: AgentEventSink,
|
||||
): Promise<ExecutedToolCallOutcome> {
|
||||
const updateEvents: Promise<void>[] = [];
|
||||
|
||||
try {
|
||||
const result = await prepared.tool.execute(
|
||||
prepared.toolCall.id,
|
||||
prepared.args as never,
|
||||
signal,
|
||||
(partialResult) => {
|
||||
updateEvents.push(
|
||||
Promise.resolve(
|
||||
emit({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: prepared.toolCall.id,
|
||||
toolName: prepared.toolCall.name,
|
||||
args: prepared.toolCall.arguments,
|
||||
partialResult,
|
||||
}),
|
||||
),
|
||||
);
|
||||
},
|
||||
);
|
||||
await Promise.all(updateEvents);
|
||||
return { result, isError: false };
|
||||
} catch (error) {
|
||||
await Promise.all(updateEvents);
|
||||
return {
|
||||
result: createErrorToolResult(error instanceof Error ? error.message : String(error)),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async function finalizeExecutedToolCall(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
prepared: PreparedToolCall,
|
||||
executed: ExecutedToolCallOutcome,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
emit: AgentEventSink,
|
||||
): Promise<ToolResultMessage> {
|
||||
let result = executed.result;
|
||||
let isError = executed.isError;
|
||||
|
||||
if (config.afterToolCall) {
|
||||
const afterResult = await config.afterToolCall(
|
||||
{
|
||||
assistantMessage,
|
||||
toolCall: prepared.toolCall,
|
||||
args: prepared.args,
|
||||
result,
|
||||
isError,
|
||||
context: currentContext,
|
||||
},
|
||||
signal,
|
||||
);
|
||||
if (afterResult) {
|
||||
result = {
|
||||
content: afterResult.content ?? result.content,
|
||||
details: afterResult.details ?? result.details,
|
||||
};
|
||||
isError = afterResult.isError ?? isError;
|
||||
}
|
||||
}
|
||||
|
||||
return await emitToolCallOutcome(prepared.toolCall, result, isError, emit);
|
||||
}
|
||||
|
||||
function createErrorToolResult(message: string): AgentToolResult<any> {
|
||||
return {
|
||||
content: [{ type: "text", text: message }],
|
||||
details: {},
|
||||
};
|
||||
}
|
||||
|
||||
async function emitToolCallOutcome(
|
||||
toolCall: AgentToolCall,
|
||||
result: AgentToolResult<any>,
|
||||
isError: boolean,
|
||||
emit: AgentEventSink,
|
||||
): Promise<ToolResultMessage> {
|
||||
await emit({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: result.details,
|
||||
isError,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
await emit({ type: "message_start", message: toolResultMessage });
|
||||
await emit({ type: "message_end", message: toolResultMessage });
|
||||
return toolResultMessage;
|
||||
}
|
||||
613
packages/agent/src/agent.ts
Normal file
613
packages/agent/src/agent.ts
Normal file
|
|
@ -0,0 +1,613 @@
|
|||
/**
|
||||
* Agent class that uses the agent-loop directly.
|
||||
* No transport abstraction - calls streamSimple via the loop.
|
||||
*/
|
||||
|
||||
import {
|
||||
getModel,
|
||||
type ImageContent,
|
||||
type Message,
|
||||
type Model,
|
||||
type SimpleStreamOptions,
|
||||
streamSimple,
|
||||
type TextContent,
|
||||
type ThinkingBudgets,
|
||||
type Transport,
|
||||
} from "@jaeswift/jae-ai";
|
||||
import { runAgentLoop, runAgentLoopContinue } from "./agent-loop.js";
|
||||
import type {
|
||||
AfterToolCallContext,
|
||||
AfterToolCallResult,
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentState,
|
||||
AgentTool,
|
||||
BeforeToolCallContext,
|
||||
BeforeToolCallResult,
|
||||
StreamFn,
|
||||
ThinkingLevel,
|
||||
ToolExecutionMode,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Default convertToLlm: Keep only LLM-compatible messages, convert attachments.
|
||||
*/
|
||||
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
|
||||
return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult");
|
||||
}
|
||||
|
||||
export interface AgentOptions {
|
||||
initialState?: Partial<AgentState>;
|
||||
|
||||
/**
|
||||
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
|
||||
* Default filters to user/assistant/toolResult and converts attachments.
|
||||
*/
|
||||
convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
|
||||
/**
|
||||
* Optional transform applied to context before convertToLlm.
|
||||
* Use for context pruning, injecting external context, etc.
|
||||
*/
|
||||
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn
|
||||
*/
|
||||
steeringMode?: "all" | "one-at-a-time";
|
||||
|
||||
/**
|
||||
* Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn
|
||||
*/
|
||||
followUpMode?: "all" | "one-at-a-time";
|
||||
|
||||
/**
|
||||
* Custom stream function (for proxy backends, etc.). Default uses streamSimple.
|
||||
*/
|
||||
streamFn?: StreamFn;
|
||||
|
||||
/**
|
||||
* Optional session identifier forwarded to LLM providers.
|
||||
* Used by providers that support session-based caching (e.g., OpenAI Codex).
|
||||
*/
|
||||
sessionId?: string;
|
||||
|
||||
/**
|
||||
* Resolves an API key dynamically for each LLM call.
|
||||
* Useful for expiring tokens (e.g., GitHub Copilot OAuth).
|
||||
*/
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
|
||||
/**
|
||||
* Inspect or replace provider payloads before they are sent.
|
||||
*/
|
||||
onPayload?: SimpleStreamOptions["onPayload"];
|
||||
|
||||
/**
|
||||
* Custom token budgets for thinking levels (token-based providers only).
|
||||
*/
|
||||
thinkingBudgets?: ThinkingBudgets;
|
||||
|
||||
/**
|
||||
* Preferred transport for providers that support multiple transports.
|
||||
*/
|
||||
transport?: Transport;
|
||||
|
||||
/**
|
||||
* Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
|
||||
* If the server's requested delay exceeds this value, the request fails immediately,
|
||||
* allowing higher-level retry logic to handle it with user visibility.
|
||||
* Default: 60000 (60 seconds). Set to 0 to disable the cap.
|
||||
*/
|
||||
maxRetryDelayMs?: number;
|
||||
|
||||
/** Tool execution mode. Default: "parallel" */
|
||||
toolExecution?: ToolExecutionMode;
|
||||
|
||||
/** Called before a tool is executed, after arguments have been validated. */
|
||||
beforeToolCall?: (context: BeforeToolCallContext, signal?: AbortSignal) => Promise<BeforeToolCallResult | undefined>;
|
||||
|
||||
/** Called after a tool finishes executing, before final tool events are emitted. */
|
||||
afterToolCall?: (context: AfterToolCallContext, signal?: AbortSignal) => Promise<AfterToolCallResult | undefined>;
|
||||
}
|
||||
|
||||
export class Agent {
|
||||
private _state: AgentState = {
|
||||
systemPrompt: "",
|
||||
model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"),
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
messages: [],
|
||||
isStreaming: false,
|
||||
streamMessage: null,
|
||||
pendingToolCalls: new Set<string>(),
|
||||
error: undefined,
|
||||
};
|
||||
|
||||
private listeners = new Set<(e: AgentEvent) => void>();
|
||||
private abortController?: AbortController;
|
||||
private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
private steeringQueue: AgentMessage[] = [];
|
||||
private followUpQueue: AgentMessage[] = [];
|
||||
private steeringMode: "all" | "one-at-a-time";
|
||||
private followUpMode: "all" | "one-at-a-time";
|
||||
public streamFn: StreamFn;
|
||||
private _sessionId?: string;
|
||||
public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
private _onPayload?: SimpleStreamOptions["onPayload"];
|
||||
private runningPrompt?: Promise<void>;
|
||||
private resolveRunningPrompt?: () => void;
|
||||
private _thinkingBudgets?: ThinkingBudgets;
|
||||
private _transport: Transport;
|
||||
private _maxRetryDelayMs?: number;
|
||||
private _toolExecution: ToolExecutionMode;
|
||||
private _beforeToolCall?: (
|
||||
context: BeforeToolCallContext,
|
||||
signal?: AbortSignal,
|
||||
) => Promise<BeforeToolCallResult | undefined>;
|
||||
private _afterToolCall?: (
|
||||
context: AfterToolCallContext,
|
||||
signal?: AbortSignal,
|
||||
) => Promise<AfterToolCallResult | undefined>;
|
||||
|
||||
constructor(opts: AgentOptions = {}) {
|
||||
this._state = { ...this._state, ...opts.initialState };
|
||||
this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
|
||||
this.transformContext = opts.transformContext;
|
||||
this.steeringMode = opts.steeringMode || "one-at-a-time";
|
||||
this.followUpMode = opts.followUpMode || "one-at-a-time";
|
||||
this.streamFn = opts.streamFn || streamSimple;
|
||||
this._sessionId = opts.sessionId;
|
||||
this.getApiKey = opts.getApiKey;
|
||||
this._onPayload = opts.onPayload;
|
||||
this._thinkingBudgets = opts.thinkingBudgets;
|
||||
this._transport = opts.transport ?? "sse";
|
||||
this._maxRetryDelayMs = opts.maxRetryDelayMs;
|
||||
this._toolExecution = opts.toolExecution ?? "parallel";
|
||||
this._beforeToolCall = opts.beforeToolCall;
|
||||
this._afterToolCall = opts.afterToolCall;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current session ID used for provider caching.
|
||||
*/
|
||||
get sessionId(): string | undefined {
|
||||
return this._sessionId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the session ID for provider caching.
|
||||
* Call this when switching sessions (new session, branch, resume).
|
||||
*/
|
||||
set sessionId(value: string | undefined) {
|
||||
this._sessionId = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current thinking budgets.
|
||||
*/
|
||||
get thinkingBudgets(): ThinkingBudgets | undefined {
|
||||
return this._thinkingBudgets;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set custom thinking budgets for token-based providers.
|
||||
*/
|
||||
set thinkingBudgets(value: ThinkingBudgets | undefined) {
|
||||
this._thinkingBudgets = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current preferred transport.
|
||||
*/
|
||||
get transport(): Transport {
|
||||
return this._transport;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the preferred transport.
|
||||
*/
|
||||
setTransport(value: Transport) {
|
||||
this._transport = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current max retry delay in milliseconds.
|
||||
*/
|
||||
get maxRetryDelayMs(): number | undefined {
|
||||
return this._maxRetryDelayMs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the maximum delay to wait for server-requested retries.
|
||||
* Set to 0 to disable the cap.
|
||||
*/
|
||||
set maxRetryDelayMs(value: number | undefined) {
|
||||
this._maxRetryDelayMs = value;
|
||||
}
|
||||
|
||||
get toolExecution(): ToolExecutionMode {
|
||||
return this._toolExecution;
|
||||
}
|
||||
|
||||
setToolExecution(value: ToolExecutionMode) {
|
||||
this._toolExecution = value;
|
||||
}
|
||||
|
||||
setBeforeToolCall(
|
||||
value:
|
||||
| ((context: BeforeToolCallContext, signal?: AbortSignal) => Promise<BeforeToolCallResult | undefined>)
|
||||
| undefined,
|
||||
) {
|
||||
this._beforeToolCall = value;
|
||||
}
|
||||
|
||||
setAfterToolCall(
|
||||
value:
|
||||
| ((context: AfterToolCallContext, signal?: AbortSignal) => Promise<AfterToolCallResult | undefined>)
|
||||
| undefined,
|
||||
) {
|
||||
this._afterToolCall = value;
|
||||
}
|
||||
|
||||
get state(): AgentState {
|
||||
return this._state;
|
||||
}
|
||||
|
||||
subscribe(fn: (e: AgentEvent) => void): () => void {
|
||||
this.listeners.add(fn);
|
||||
return () => this.listeners.delete(fn);
|
||||
}
|
||||
|
||||
// State mutators
|
||||
setSystemPrompt(v: string) {
|
||||
this._state.systemPrompt = v;
|
||||
}
|
||||
|
||||
setModel(m: Model<any>) {
|
||||
this._state.model = m;
|
||||
}
|
||||
|
||||
setThinkingLevel(l: ThinkingLevel) {
|
||||
this._state.thinkingLevel = l;
|
||||
}
|
||||
|
||||
setSteeringMode(mode: "all" | "one-at-a-time") {
|
||||
this.steeringMode = mode;
|
||||
}
|
||||
|
||||
getSteeringMode(): "all" | "one-at-a-time" {
|
||||
return this.steeringMode;
|
||||
}
|
||||
|
||||
setFollowUpMode(mode: "all" | "one-at-a-time") {
|
||||
this.followUpMode = mode;
|
||||
}
|
||||
|
||||
getFollowUpMode(): "all" | "one-at-a-time" {
|
||||
return this.followUpMode;
|
||||
}
|
||||
|
||||
setTools(t: AgentTool<any>[]) {
|
||||
this._state.tools = t;
|
||||
}
|
||||
|
||||
replaceMessages(ms: AgentMessage[]) {
|
||||
this._state.messages = ms.slice();
|
||||
}
|
||||
|
||||
appendMessage(m: AgentMessage) {
|
||||
this._state.messages = [...this._state.messages, m];
|
||||
}
|
||||
|
||||
/**
|
||||
* Queue a steering message while the agent is running.
|
||||
* Delivered after the current assistant turn finishes executing its tool calls,
|
||||
* before the next LLM call.
|
||||
*/
|
||||
steer(m: AgentMessage) {
|
||||
this.steeringQueue.push(m);
|
||||
}
|
||||
|
||||
/**
|
||||
* Queue a follow-up message to be processed after the agent finishes.
|
||||
* Delivered only when agent has no more tool calls or steering messages.
|
||||
*/
|
||||
followUp(m: AgentMessage) {
|
||||
this.followUpQueue.push(m);
|
||||
}
|
||||
|
||||
clearSteeringQueue() {
|
||||
this.steeringQueue = [];
|
||||
}
|
||||
|
||||
clearFollowUpQueue() {
|
||||
this.followUpQueue = [];
|
||||
}
|
||||
|
||||
clearAllQueues() {
|
||||
this.steeringQueue = [];
|
||||
this.followUpQueue = [];
|
||||
}
|
||||
|
||||
hasQueuedMessages(): boolean {
|
||||
return this.steeringQueue.length > 0 || this.followUpQueue.length > 0;
|
||||
}
|
||||
|
||||
private dequeueSteeringMessages(): AgentMessage[] {
|
||||
if (this.steeringMode === "one-at-a-time") {
|
||||
if (this.steeringQueue.length > 0) {
|
||||
const first = this.steeringQueue[0];
|
||||
this.steeringQueue = this.steeringQueue.slice(1);
|
||||
return [first];
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
const steering = this.steeringQueue.slice();
|
||||
this.steeringQueue = [];
|
||||
return steering;
|
||||
}
|
||||
|
||||
private dequeueFollowUpMessages(): AgentMessage[] {
|
||||
if (this.followUpMode === "one-at-a-time") {
|
||||
if (this.followUpQueue.length > 0) {
|
||||
const first = this.followUpQueue[0];
|
||||
this.followUpQueue = this.followUpQueue.slice(1);
|
||||
return [first];
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
const followUp = this.followUpQueue.slice();
|
||||
this.followUpQueue = [];
|
||||
return followUp;
|
||||
}
|
||||
|
||||
clearMessages() {
|
||||
this._state.messages = [];
|
||||
}
|
||||
|
||||
abort() {
|
||||
this.abortController?.abort();
|
||||
}
|
||||
|
||||
waitForIdle(): Promise<void> {
|
||||
return this.runningPrompt ?? Promise.resolve();
|
||||
}
|
||||
|
||||
reset() {
|
||||
this._state.messages = [];
|
||||
this._state.isStreaming = false;
|
||||
this._state.streamMessage = null;
|
||||
this._state.pendingToolCalls = new Set<string>();
|
||||
this._state.error = undefined;
|
||||
this.steeringQueue = [];
|
||||
this.followUpQueue = [];
|
||||
}
|
||||
|
||||
/** Send a prompt with an AgentMessage */
|
||||
async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
|
||||
async prompt(input: string, images?: ImageContent[]): Promise<void>;
|
||||
async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
|
||||
if (this._state.isStreaming) {
|
||||
throw new Error(
|
||||
"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
|
||||
);
|
||||
}
|
||||
|
||||
const model = this._state.model;
|
||||
if (!model) throw new Error("No model configured");
|
||||
|
||||
let msgs: AgentMessage[];
|
||||
|
||||
if (Array.isArray(input)) {
|
||||
msgs = input;
|
||||
} else if (typeof input === "string") {
|
||||
const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
|
||||
if (images && images.length > 0) {
|
||||
content.push(...images);
|
||||
}
|
||||
msgs = [
|
||||
{
|
||||
role: "user",
|
||||
content,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
} else {
|
||||
msgs = [input];
|
||||
}
|
||||
|
||||
await this._runLoop(msgs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue from current context (used for retries and resuming queued messages).
|
||||
*/
|
||||
async continue() {
|
||||
if (this._state.isStreaming) {
|
||||
throw new Error("Agent is already processing. Wait for completion before continuing.");
|
||||
}
|
||||
|
||||
const messages = this._state.messages;
|
||||
if (messages.length === 0) {
|
||||
throw new Error("No messages to continue from");
|
||||
}
|
||||
if (messages[messages.length - 1].role === "assistant") {
|
||||
const queuedSteering = this.dequeueSteeringMessages();
|
||||
if (queuedSteering.length > 0) {
|
||||
await this._runLoop(queuedSteering, { skipInitialSteeringPoll: true });
|
||||
return;
|
||||
}
|
||||
|
||||
const queuedFollowUp = this.dequeueFollowUpMessages();
|
||||
if (queuedFollowUp.length > 0) {
|
||||
await this._runLoop(queuedFollowUp);
|
||||
return;
|
||||
}
|
||||
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
await this._runLoop(undefined);
|
||||
}
|
||||
|
||||
private _processLoopEvent(event: AgentEvent): void {
|
||||
switch (event.type) {
|
||||
case "message_start":
|
||||
this._state.streamMessage = event.message;
|
||||
break;
|
||||
|
||||
case "message_update":
|
||||
this._state.streamMessage = event.message;
|
||||
break;
|
||||
|
||||
case "message_end":
|
||||
this._state.streamMessage = null;
|
||||
this.appendMessage(event.message);
|
||||
break;
|
||||
|
||||
case "tool_execution_start": {
|
||||
const pendingToolCalls = new Set(this._state.pendingToolCalls);
|
||||
pendingToolCalls.add(event.toolCallId);
|
||||
this._state.pendingToolCalls = pendingToolCalls;
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool_execution_end": {
|
||||
const pendingToolCalls = new Set(this._state.pendingToolCalls);
|
||||
pendingToolCalls.delete(event.toolCallId);
|
||||
this._state.pendingToolCalls = pendingToolCalls;
|
||||
break;
|
||||
}
|
||||
|
||||
case "turn_end":
|
||||
if (event.message.role === "assistant" && (event.message as any).errorMessage) {
|
||||
this._state.error = (event.message as any).errorMessage;
|
||||
}
|
||||
break;
|
||||
|
||||
case "agent_end":
|
||||
this._state.isStreaming = false;
|
||||
this._state.streamMessage = null;
|
||||
break;
|
||||
}
|
||||
|
||||
this.emit(event);
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the agent loop.
|
||||
* If messages are provided, starts a new conversation turn with those messages.
|
||||
* Otherwise, continues from existing context.
|
||||
*/
|
||||
private async _runLoop(messages?: AgentMessage[], options?: { skipInitialSteeringPoll?: boolean }) {
|
||||
const model = this._state.model;
|
||||
if (!model) throw new Error("No model configured");
|
||||
|
||||
this.runningPrompt = new Promise<void>((resolve) => {
|
||||
this.resolveRunningPrompt = resolve;
|
||||
});
|
||||
|
||||
this.abortController = new AbortController();
|
||||
this._state.isStreaming = true;
|
||||
this._state.streamMessage = null;
|
||||
this._state.error = undefined;
|
||||
|
||||
const reasoning = this._state.thinkingLevel === "off" ? undefined : this._state.thinkingLevel;
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: this._state.systemPrompt,
|
||||
messages: this._state.messages.slice(),
|
||||
tools: this._state.tools,
|
||||
};
|
||||
|
||||
let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model,
|
||||
reasoning,
|
||||
sessionId: this._sessionId,
|
||||
onPayload: this._onPayload,
|
||||
transport: this._transport,
|
||||
thinkingBudgets: this._thinkingBudgets,
|
||||
maxRetryDelayMs: this._maxRetryDelayMs,
|
||||
toolExecution: this._toolExecution,
|
||||
beforeToolCall: this._beforeToolCall,
|
||||
afterToolCall: this._afterToolCall,
|
||||
convertToLlm: this.convertToLlm,
|
||||
transformContext: this.transformContext,
|
||||
getApiKey: this.getApiKey,
|
||||
getSteeringMessages: async () => {
|
||||
if (skipInitialSteeringPoll) {
|
||||
skipInitialSteeringPoll = false;
|
||||
return [];
|
||||
}
|
||||
return this.dequeueSteeringMessages();
|
||||
},
|
||||
getFollowUpMessages: async () => this.dequeueFollowUpMessages(),
|
||||
};
|
||||
|
||||
try {
|
||||
if (messages) {
|
||||
await runAgentLoop(
|
||||
messages,
|
||||
context,
|
||||
config,
|
||||
async (event) => this._processLoopEvent(event),
|
||||
this.abortController.signal,
|
||||
this.streamFn,
|
||||
);
|
||||
} else {
|
||||
await runAgentLoopContinue(
|
||||
context,
|
||||
config,
|
||||
async (event) => this._processLoopEvent(event),
|
||||
this.abortController.signal,
|
||||
this.streamFn,
|
||||
);
|
||||
}
|
||||
} catch (err: any) {
|
||||
const errorMsg: AgentMessage = {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "" }],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
|
||||
errorMessage: err?.message || String(err),
|
||||
timestamp: Date.now(),
|
||||
} as AgentMessage;
|
||||
|
||||
this.appendMessage(errorMsg);
|
||||
this._state.error = err?.message || String(err);
|
||||
this.emit({ type: "agent_end", messages: [errorMsg] });
|
||||
} finally {
|
||||
this._state.isStreaming = false;
|
||||
this._state.streamMessage = null;
|
||||
this._state.pendingToolCalls = new Set<string>();
|
||||
this.abortController = undefined;
|
||||
this.resolveRunningPrompt?.();
|
||||
this.runningPrompt = undefined;
|
||||
this.resolveRunningPrompt = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private emit(e: AgentEvent) {
|
||||
for (const listener of this.listeners) {
|
||||
listener(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
8
packages/agent/src/index.ts
Normal file
8
packages/agent/src/index.ts
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
// Core Agent
|
||||
export * from "./agent.js";
|
||||
// Loop functions
|
||||
export * from "./agent-loop.js";
|
||||
// Proxy utilities
|
||||
export * from "./proxy.js";
|
||||
// Types
|
||||
export * from "./types.js";
|
||||
340
packages/agent/src/proxy.ts
Normal file
340
packages/agent/src/proxy.ts
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
/**
|
||||
* Proxy stream function for apps that route LLM calls through a server.
|
||||
* The server manages auth and proxies requests to LLM providers.
|
||||
*/
|
||||
|
||||
// Internal import for JSON parsing utility
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
type Context,
|
||||
EventStream,
|
||||
type Model,
|
||||
parseStreamingJson,
|
||||
type SimpleStreamOptions,
|
||||
type StopReason,
|
||||
type ToolCall,
|
||||
} from "@jaeswift/jae-ai";
|
||||
|
||||
// Create stream class matching ProxyMessageEventStream
|
||||
class ProxyMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Proxy event types - server sends these with partial field stripped to reduce bandwidth.
|
||||
*/
|
||||
export type ProxyAssistantMessageEvent =
|
||||
| { type: "start" }
|
||||
| { type: "text_start"; contentIndex: number }
|
||||
| { type: "text_delta"; contentIndex: number; delta: string }
|
||||
| { type: "text_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "thinking_start"; contentIndex: number }
|
||||
| { type: "thinking_delta"; contentIndex: number; delta: string }
|
||||
| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
|
||||
| { type: "toolcall_delta"; contentIndex: number; delta: string }
|
||||
| { type: "toolcall_end"; contentIndex: number }
|
||||
| {
|
||||
type: "done";
|
||||
reason: Extract<StopReason, "stop" | "length" | "toolUse">;
|
||||
usage: AssistantMessage["usage"];
|
||||
}
|
||||
| {
|
||||
type: "error";
|
||||
reason: Extract<StopReason, "aborted" | "error">;
|
||||
errorMessage?: string;
|
||||
usage: AssistantMessage["usage"];
|
||||
};
|
||||
|
||||
export interface ProxyStreamOptions extends SimpleStreamOptions {
|
||||
/** Auth token for the proxy server */
|
||||
authToken: string;
|
||||
/** Proxy server URL (e.g., "https://genai.example.com") */
|
||||
proxyUrl: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream function that proxies through a server instead of calling LLM providers directly.
|
||||
* The server strips the partial field from delta events to reduce bandwidth.
|
||||
* We reconstruct the partial message client-side.
|
||||
*
|
||||
* Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const agent = new Agent({
|
||||
* streamFn: (model, context, options) =>
|
||||
* streamProxy(model, context, {
|
||||
* ...options,
|
||||
* authToken: await getAuthToken(),
|
||||
* proxyUrl: "https://genai.example.com",
|
||||
* }),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
export function streamProxy(model: Model<any>, context: Context, options: ProxyStreamOptions): ProxyMessageEventStream {
|
||||
const stream = new ProxyMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
// Initialize the partial message that we'll build up from events
|
||||
const partial: AssistantMessage = {
|
||||
role: "assistant",
|
||||
stopReason: "stop",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
|
||||
|
||||
const abortHandler = () => {
|
||||
if (reader) {
|
||||
reader.cancel("Request aborted by user").catch(() => {});
|
||||
}
|
||||
};
|
||||
|
||||
if (options.signal) {
|
||||
options.signal.addEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${options.proxyUrl}/api/stream`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${options.authToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
context,
|
||||
options: {
|
||||
temperature: options.temperature,
|
||||
maxTokens: options.maxTokens,
|
||||
reasoning: options.reasoning,
|
||||
},
|
||||
}),
|
||||
signal: options.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
|
||||
try {
|
||||
const errorData = (await response.json()) as { error?: string };
|
||||
if (errorData.error) {
|
||||
errorMessage = `Proxy error: ${errorData.error}`;
|
||||
}
|
||||
} catch {
|
||||
// Couldn't parse error response
|
||||
}
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
reader = response.body!.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6).trim();
|
||||
if (data) {
|
||||
const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
|
||||
const event = processProxyEvent(proxyEvent, partial);
|
||||
if (event) {
|
||||
stream.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
const reason = options.signal?.aborted ? "aborted" : "error";
|
||||
partial.stopReason = reason;
|
||||
partial.errorMessage = errorMessage;
|
||||
stream.push({
|
||||
type: "error",
|
||||
reason,
|
||||
error: partial,
|
||||
});
|
||||
stream.end();
|
||||
} finally {
|
||||
if (options.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a proxy event and update the partial message.
|
||||
*/
|
||||
function processProxyEvent(
|
||||
proxyEvent: ProxyAssistantMessageEvent,
|
||||
partial: AssistantMessage,
|
||||
): AssistantMessageEvent | undefined {
|
||||
switch (proxyEvent.type) {
|
||||
case "start":
|
||||
return { type: "start", partial };
|
||||
|
||||
case "text_start":
|
||||
partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
|
||||
return { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "text_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.text += proxyEvent.delta;
|
||||
return {
|
||||
type: "text_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_delta for non-text content");
|
||||
}
|
||||
|
||||
case "text_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.textSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "text_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.text,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_end for non-text content");
|
||||
}
|
||||
|
||||
case "thinking_start":
|
||||
partial.content[proxyEvent.contentIndex] = { type: "thinking", thinking: "" };
|
||||
return { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "thinking_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinking += proxyEvent.delta;
|
||||
return {
|
||||
type: "thinking_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_delta for non-thinking content");
|
||||
}
|
||||
|
||||
case "thinking_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinkingSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "thinking_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.thinking,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_end for non-thinking content");
|
||||
}
|
||||
|
||||
case "toolcall_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "toolCall",
|
||||
id: proxyEvent.id,
|
||||
name: proxyEvent.toolName,
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
} satisfies ToolCall & { partialJson: string } as ToolCall;
|
||||
return { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "toolcall_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
(content as any).partialJson += proxyEvent.delta;
|
||||
content.arguments = parseStreamingJson((content as any).partialJson) || {};
|
||||
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
|
||||
return {
|
||||
type: "toolcall_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received toolcall_delta for non-toolCall content");
|
||||
}
|
||||
|
||||
case "toolcall_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
delete (content as any).partialJson;
|
||||
return {
|
||||
type: "toolcall_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
toolCall: content,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
case "done":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "done", reason: proxyEvent.reason, message: partial };
|
||||
|
||||
case "error":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.errorMessage = proxyEvent.errorMessage;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "error", reason: proxyEvent.reason, error: partial };
|
||||
|
||||
default: {
|
||||
const _exhaustiveCheck: never = proxyEvent;
|
||||
console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
310
packages/agent/src/types.ts
Normal file
310
packages/agent/src/types.ts
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
import type {
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
streamSimple,
|
||||
TextContent,
|
||||
Tool,
|
||||
ToolResultMessage,
|
||||
} from "@jaeswift/jae-ai";
|
||||
import type { Static, TSchema } from "@sinclair/typebox";
|
||||
|
||||
/**
|
||||
* Stream function used by the agent loop.
|
||||
*
|
||||
* Contract:
|
||||
* - Must not throw or return a rejected promise for request/model/runtime failures.
|
||||
* - Must return an AssistantMessageEventStream.
|
||||
* - Failures must be encoded in the returned stream via protocol events and a
|
||||
* final AssistantMessage with stopReason "error" or "aborted" and errorMessage.
|
||||
*/
|
||||
export type StreamFn = (
|
||||
...args: Parameters<typeof streamSimple>
|
||||
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
|
||||
|
||||
/**
|
||||
* Configuration for how tool calls from a single assistant message are executed.
|
||||
*
|
||||
* - "sequential": each tool call is prepared, executed, and finalized before the next one starts.
|
||||
* - "parallel": tool calls are prepared sequentially, then allowed tools execute concurrently.
|
||||
* Final tool results are still emitted in assistant source order.
|
||||
*/
|
||||
export type ToolExecutionMode = "sequential" | "parallel";
|
||||
|
||||
/** A single tool call content block emitted by an assistant message. */
|
||||
export type AgentToolCall = Extract<AssistantMessage["content"][number], { type: "toolCall" }>;
|
||||
|
||||
/**
|
||||
* Result returned from `beforeToolCall`.
|
||||
*
|
||||
* Returning `{ block: true }` prevents the tool from executing. The loop emits an error tool result instead.
|
||||
* `reason` becomes the text shown in that error result. If omitted, a default blocked message is used.
|
||||
*/
|
||||
export interface BeforeToolCallResult {
|
||||
block?: boolean;
|
||||
reason?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Partial override returned from `afterToolCall`.
|
||||
*
|
||||
* Merge semantics are field-by-field:
|
||||
* - `content`: if provided, replaces the tool result content array in full
|
||||
* - `details`: if provided, replaces the tool result details value in full
|
||||
* - `isError`: if provided, replaces the tool result error flag
|
||||
*
|
||||
* Omitted fields keep the original executed tool result values.
|
||||
* There is no deep merge for `content` or `details`.
|
||||
*/
|
||||
export interface AfterToolCallResult {
|
||||
content?: (TextContent | ImageContent)[];
|
||||
details?: unknown;
|
||||
isError?: boolean;
|
||||
}
|
||||
|
||||
/** Context passed to `beforeToolCall`. */
|
||||
export interface BeforeToolCallContext {
|
||||
/** The assistant message that requested the tool call. */
|
||||
assistantMessage: AssistantMessage;
|
||||
/** The raw tool call block from `assistantMessage.content`. */
|
||||
toolCall: AgentToolCall;
|
||||
/** Validated tool arguments for the target tool schema. */
|
||||
args: unknown;
|
||||
/** Current agent context at the time the tool call is prepared. */
|
||||
context: AgentContext;
|
||||
}
|
||||
|
||||
/** Context passed to `afterToolCall`. */
|
||||
export interface AfterToolCallContext {
|
||||
/** The assistant message that requested the tool call. */
|
||||
assistantMessage: AssistantMessage;
|
||||
/** The raw tool call block from `assistantMessage.content`. */
|
||||
toolCall: AgentToolCall;
|
||||
/** Validated tool arguments for the target tool schema. */
|
||||
args: unknown;
|
||||
/** The executed tool result before any `afterToolCall` overrides are applied. */
|
||||
result: AgentToolResult<any>;
|
||||
/** Whether the executed tool result is currently treated as an error. */
|
||||
isError: boolean;
|
||||
/** Current agent context at the time the tool call is finalized. */
|
||||
context: AgentContext;
|
||||
}
|
||||
|
||||
export interface AgentLoopConfig extends SimpleStreamOptions {
|
||||
model: Model<any>;
|
||||
|
||||
/**
|
||||
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
|
||||
*
|
||||
* Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
|
||||
* that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
|
||||
* status messages) should be filtered out.
|
||||
*
|
||||
* Contract: must not throw or reject. Return a safe fallback value instead.
|
||||
* Throwing interrupts the low-level agent loop without producing a normal event sequence.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* convertToLlm: (messages) => messages.flatMap(m => {
|
||||
* if (m.role === "custom") {
|
||||
* // Convert custom message to user message
|
||||
* return [{ role: "user", content: m.content, timestamp: m.timestamp }];
|
||||
* }
|
||||
* if (m.role === "notification") {
|
||||
* // Filter out UI-only messages
|
||||
* return [];
|
||||
* }
|
||||
* // Pass through standard LLM messages
|
||||
* return [m];
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
|
||||
/**
|
||||
* Optional transform applied to the context before `convertToLlm`.
|
||||
*
|
||||
* Use this for operations that work at the AgentMessage level:
|
||||
* - Context window management (pruning old messages)
|
||||
* - Injecting context from external sources
|
||||
*
|
||||
* Contract: must not throw or reject. Return the original messages or another
|
||||
* safe fallback value instead.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* transformContext: async (messages) => {
|
||||
* if (estimateTokens(messages) > MAX_TOKENS) {
|
||||
* return pruneOldMessages(messages);
|
||||
* }
|
||||
* return messages;
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Resolves an API key dynamically for each LLM call.
|
||||
*
|
||||
* Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
|
||||
* during long-running tool execution phases.
|
||||
*
|
||||
* Contract: must not throw or reject. Return undefined when no key is available.
|
||||
*/
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
|
||||
/**
|
||||
* Returns steering messages to inject into the conversation mid-run.
|
||||
*
|
||||
* Called after the current assistant turn finishes executing its tool calls.
|
||||
* If messages are returned, they are added to the context before the next LLM call.
|
||||
* Tool calls from the current assistant message are not skipped.
|
||||
*
|
||||
* Use this for "steering" the agent while it's working.
|
||||
*
|
||||
* Contract: must not throw or reject. Return [] when no steering messages are available.
|
||||
*/
|
||||
getSteeringMessages?: () => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Returns follow-up messages to process after the agent would otherwise stop.
|
||||
*
|
||||
* Called when the agent has no more tool calls and no steering messages.
|
||||
* If messages are returned, they're added to the context and the agent
|
||||
* continues with another turn.
|
||||
*
|
||||
* Use this for follow-up messages that should wait until the agent finishes.
|
||||
*
|
||||
* Contract: must not throw or reject. Return [] when no follow-up messages are available.
|
||||
*/
|
||||
getFollowUpMessages?: () => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Tool execution mode.
|
||||
* - "sequential": execute tool calls one by one
|
||||
* - "parallel": preflight tool calls sequentially, then execute allowed tools concurrently
|
||||
*
|
||||
* Default: "parallel"
|
||||
*/
|
||||
toolExecution?: ToolExecutionMode;
|
||||
|
||||
/**
|
||||
* Called before a tool is executed, after arguments have been validated.
|
||||
*
|
||||
* Return `{ block: true }` to prevent execution. The loop emits an error tool result instead.
|
||||
* The hook receives the agent abort signal and is responsible for honoring it.
|
||||
*/
|
||||
beforeToolCall?: (context: BeforeToolCallContext, signal?: AbortSignal) => Promise<BeforeToolCallResult | undefined>;
|
||||
|
||||
/**
|
||||
* Called after a tool finishes executing, before final tool events are emitted.
|
||||
*
|
||||
* Return an `AfterToolCallResult` to override parts of the executed tool result:
|
||||
* - `content` replaces the full content array
|
||||
* - `details` replaces the full details payload
|
||||
* - `isError` replaces the error flag
|
||||
*
|
||||
* Any omitted fields keep their original values. No deep merge is performed.
|
||||
* The hook receives the agent abort signal and is responsible for honoring it.
|
||||
*/
|
||||
afterToolCall?: (context: AfterToolCallContext, signal?: AbortSignal) => Promise<AfterToolCallResult | undefined>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Thinking/reasoning level for models that support it.
|
||||
* Note: "xhigh" is only supported by OpenAI gpt-5.1-codex-max, gpt-5.2, gpt-5.2-codex, gpt-5.3, and gpt-5.3-codex models.
|
||||
*/
|
||||
export type ThinkingLevel = "off" | "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
|
||||
/**
|
||||
* Extensible interface for custom app messages.
|
||||
* Apps can extend via declaration merging:
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* declare module "@mariozechner/agent" {
|
||||
* interface CustomAgentMessages {
|
||||
* artifact: ArtifactMessage;
|
||||
* notification: NotificationMessage;
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export interface CustomAgentMessages {
|
||||
// Empty by default - apps extend via declaration merging
|
||||
}
|
||||
|
||||
/**
|
||||
* AgentMessage: Union of LLM messages + custom messages.
|
||||
* This abstraction allows apps to add custom message types while maintaining
|
||||
* type safety and compatibility with the base LLM messages.
|
||||
*/
|
||||
export type AgentMessage = Message | CustomAgentMessages[keyof CustomAgentMessages];
|
||||
|
||||
/**
|
||||
* Agent state containing all configuration and conversation data.
|
||||
*/
|
||||
export interface AgentState {
|
||||
systemPrompt: string;
|
||||
model: Model<any>;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
tools: AgentTool<any>[];
|
||||
messages: AgentMessage[]; // Can include attachments + custom message types
|
||||
isStreaming: boolean;
|
||||
streamMessage: AgentMessage | null;
|
||||
pendingToolCalls: Set<string>;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface AgentToolResult<T> {
|
||||
// Content blocks supporting text and images
|
||||
content: (TextContent | ImageContent)[];
|
||||
// Details to be displayed in a UI or logged
|
||||
details: T;
|
||||
}
|
||||
|
||||
// Callback for streaming tool execution updates
|
||||
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
|
||||
|
||||
// AgentTool extends Tool but adds the execute function
|
||||
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
|
||||
// A human-readable label for the tool to be displayed in UI
|
||||
label: string;
|
||||
execute: (
|
||||
toolCallId: string,
|
||||
params: Static<TParameters>,
|
||||
signal?: AbortSignal,
|
||||
onUpdate?: AgentToolUpdateCallback<TDetails>,
|
||||
) => Promise<AgentToolResult<TDetails>>;
|
||||
}
|
||||
|
||||
// AgentContext is like Context but uses AgentTool
|
||||
export interface AgentContext {
|
||||
systemPrompt: string;
|
||||
messages: AgentMessage[];
|
||||
tools?: AgentTool<any>[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Events emitted by the Agent for UI updates.
|
||||
* These events provide fine-grained lifecycle information for messages, turns, and tool executions.
|
||||
*/
|
||||
export type AgentEvent =
|
||||
// Agent lifecycle
|
||||
| { type: "agent_start" }
|
||||
| { type: "agent_end"; messages: AgentMessage[] }
|
||||
// Turn lifecycle - a turn is one assistant response + any tool calls/results
|
||||
| { type: "turn_start" }
|
||||
| { type: "turn_end"; message: AgentMessage; toolResults: ToolResultMessage[] }
|
||||
// Message lifecycle - emitted for user, assistant, and toolResult messages
|
||||
| { type: "message_start"; message: AgentMessage }
|
||||
// Only emitted for assistant messages during streaming
|
||||
| { type: "message_update"; message: AgentMessage; assistantMessageEvent: AssistantMessageEvent }
|
||||
| { type: "message_end"; message: AgentMessage }
|
||||
// Tool execution lifecycle
|
||||
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
|
||||
| { type: "tool_execution_update"; toolCallId: string; toolName: string; args: any; partialResult: any }
|
||||
| { type: "tool_execution_end"; toolCallId: string; toolName: string; result: any; isError: boolean };
|
||||
616
packages/agent/test/agent-loop.test.ts
Normal file
616
packages/agent/test/agent-loop.test.ts
Normal file
|
|
@ -0,0 +1,616 @@
|
|||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
EventStream,
|
||||
type Message,
|
||||
type Model,
|
||||
type UserMessage,
|
||||
} from "@jaeswift/jae-ai";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { agentLoop, agentLoopContinue } from "../src/agent-loop.js";
|
||||
import type { AgentContext, AgentEvent, AgentLoopConfig, AgentMessage, AgentTool } from "../src/types.js";
|
||||
|
||||
// Mock stream for testing - mimics MockAssistantStream
|
||||
class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function createUsage() {
|
||||
return {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
function createModel(): Model<"openai-responses"> {
|
||||
return {
|
||||
id: "mock",
|
||||
name: "mock",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
baseUrl: "https://example.invalid",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 8192,
|
||||
maxTokens: 2048,
|
||||
};
|
||||
}
|
||||
|
||||
function createAssistantMessage(
|
||||
content: AssistantMessage["content"],
|
||||
stopReason: AssistantMessage["stopReason"] = "stop",
|
||||
): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content,
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "mock",
|
||||
usage: createUsage(),
|
||||
stopReason,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function createUserMessage(text: string): UserMessage {
|
||||
return {
|
||||
role: "user",
|
||||
content: text,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
// Simple identity converter for tests - just passes through standard messages
|
||||
function identityConverter(messages: AgentMessage[]): Message[] {
|
||||
return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
|
||||
}
|
||||
|
||||
describe("agentLoop with AgentMessage", () => {
|
||||
it("should emit events with AgentMessage types", async () => {
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("Hello");
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Hi there!" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const messages = await stream.result();
|
||||
|
||||
// Should have user message and assistant message
|
||||
expect(messages.length).toBe(2);
|
||||
expect(messages[0].role).toBe("user");
|
||||
expect(messages[1].role).toBe("assistant");
|
||||
|
||||
// Verify event sequence
|
||||
const eventTypes = events.map((e) => e.type);
|
||||
expect(eventTypes).toContain("agent_start");
|
||||
expect(eventTypes).toContain("turn_start");
|
||||
expect(eventTypes).toContain("message_start");
|
||||
expect(eventTypes).toContain("message_end");
|
||||
expect(eventTypes).toContain("turn_end");
|
||||
expect(eventTypes).toContain("agent_end");
|
||||
});
|
||||
|
||||
it("should handle custom message types via convertToLlm", async () => {
|
||||
// Create a custom message type
|
||||
interface CustomNotification {
|
||||
role: "notification";
|
||||
text: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
const notification: CustomNotification = {
|
||||
role: "notification",
|
||||
text: "This is a notification",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [notification as unknown as AgentMessage], // Custom message in context
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("Hello");
|
||||
|
||||
let convertedMessages: Message[] = [];
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: (messages) => {
|
||||
// Filter out notifications, convert rest
|
||||
convertedMessages = messages
|
||||
.filter((m) => (m as { role: string }).role !== "notification")
|
||||
.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
|
||||
return convertedMessages;
|
||||
},
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// The notification should have been filtered out in convertToLlm
|
||||
expect(convertedMessages.length).toBe(1); // Only user message
|
||||
expect(convertedMessages[0].role).toBe("user");
|
||||
});
|
||||
|
||||
it("should apply transformContext before convertToLlm", async () => {
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [
|
||||
createUserMessage("old message 1"),
|
||||
createAssistantMessage([{ type: "text", text: "old response 1" }]),
|
||||
createUserMessage("old message 2"),
|
||||
createAssistantMessage([{ type: "text", text: "old response 2" }]),
|
||||
],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("new message");
|
||||
|
||||
let transformedMessages: AgentMessage[] = [];
|
||||
let convertedMessages: Message[] = [];
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
transformContext: async (messages) => {
|
||||
// Keep only last 2 messages (prune old ones)
|
||||
transformedMessages = messages.slice(-2);
|
||||
return transformedMessages;
|
||||
},
|
||||
convertToLlm: (messages) => {
|
||||
convertedMessages = messages.filter(
|
||||
(m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult",
|
||||
) as Message[];
|
||||
return convertedMessages;
|
||||
},
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
|
||||
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
}
|
||||
|
||||
// transformContext should have been called first, keeping only last 2
|
||||
expect(transformedMessages.length).toBe(2);
|
||||
// Then convertToLlm receives the pruned messages
|
||||
expect(convertedMessages.length).toBe(2);
|
||||
});
|
||||
|
||||
it("should handle tool calls and results", async () => {
|
||||
const toolSchema = Type.Object({ value: Type.String() });
|
||||
const executed: string[] = [];
|
||||
const tool: AgentTool<typeof toolSchema, { value: string }> = {
|
||||
name: "echo",
|
||||
label: "Echo",
|
||||
description: "Echo tool",
|
||||
parameters: toolSchema,
|
||||
async execute(_toolCallId, params) {
|
||||
executed.push(params.value);
|
||||
return {
|
||||
content: [{ type: "text", text: `echoed: ${params.value}` }],
|
||||
details: { value: params.value },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "",
|
||||
messages: [],
|
||||
tools: [tool],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("echo something");
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
let callIndex = 0;
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
if (callIndex === 0) {
|
||||
// First call: return tool call
|
||||
const message = createAssistantMessage(
|
||||
[{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "hello" } }],
|
||||
"toolUse",
|
||||
);
|
||||
stream.push({ type: "done", reason: "toolUse", message });
|
||||
} else {
|
||||
// Second call: return final response
|
||||
const message = createAssistantMessage([{ type: "text", text: "done" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
}
|
||||
callIndex++;
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// Tool should have been executed
|
||||
expect(executed).toEqual(["hello"]);
|
||||
|
||||
// Should have tool execution events
|
||||
const toolStart = events.find((e) => e.type === "tool_execution_start");
|
||||
const toolEnd = events.find((e) => e.type === "tool_execution_end");
|
||||
expect(toolStart).toBeDefined();
|
||||
expect(toolEnd).toBeDefined();
|
||||
if (toolEnd?.type === "tool_execution_end") {
|
||||
expect(toolEnd.isError).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("should execute tool calls in parallel and emit tool results in source order", async () => {
|
||||
const toolSchema = Type.Object({ value: Type.String() });
|
||||
let firstResolved = false;
|
||||
let parallelObserved = false;
|
||||
let releaseFirst: (() => void) | undefined;
|
||||
const firstDone = new Promise<void>((resolve) => {
|
||||
releaseFirst = resolve;
|
||||
});
|
||||
|
||||
const tool: AgentTool<typeof toolSchema, { value: string }> = {
|
||||
name: "echo",
|
||||
label: "Echo",
|
||||
description: "Echo tool",
|
||||
parameters: toolSchema,
|
||||
async execute(_toolCallId, params) {
|
||||
if (params.value === "first") {
|
||||
await firstDone;
|
||||
firstResolved = true;
|
||||
}
|
||||
if (params.value === "second" && !firstResolved) {
|
||||
parallelObserved = true;
|
||||
}
|
||||
return {
|
||||
content: [{ type: "text", text: `echoed: ${params.value}` }],
|
||||
details: { value: params.value },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "",
|
||||
messages: [],
|
||||
tools: [tool],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("echo both");
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
toolExecution: "parallel",
|
||||
};
|
||||
|
||||
let callIndex = 0;
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, () => {
|
||||
const mockStream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
if (callIndex === 0) {
|
||||
const message = createAssistantMessage(
|
||||
[
|
||||
{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
|
||||
{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
|
||||
],
|
||||
"toolUse",
|
||||
);
|
||||
mockStream.push({ type: "done", reason: "toolUse", message });
|
||||
setTimeout(() => releaseFirst?.(), 20);
|
||||
} else {
|
||||
const message = createAssistantMessage([{ type: "text", text: "done" }]);
|
||||
mockStream.push({ type: "done", reason: "stop", message });
|
||||
}
|
||||
callIndex++;
|
||||
});
|
||||
return mockStream;
|
||||
});
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const toolResultIds = events.flatMap((event) => {
|
||||
if (event.type !== "message_end" || event.message.role !== "toolResult") {
|
||||
return [];
|
||||
}
|
||||
return [event.message.toolCallId];
|
||||
});
|
||||
|
||||
expect(parallelObserved).toBe(true);
|
||||
expect(toolResultIds).toEqual(["tool-1", "tool-2"]);
|
||||
});
|
||||
|
||||
it("should inject queued messages after all tool calls complete", async () => {
|
||||
const toolSchema = Type.Object({ value: Type.String() });
|
||||
const executed: string[] = [];
|
||||
const tool: AgentTool<typeof toolSchema, { value: string }> = {
|
||||
name: "echo",
|
||||
label: "Echo",
|
||||
description: "Echo tool",
|
||||
parameters: toolSchema,
|
||||
async execute(_toolCallId, params) {
|
||||
executed.push(params.value);
|
||||
return {
|
||||
content: [{ type: "text", text: `ok:${params.value}` }],
|
||||
details: { value: params.value },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "",
|
||||
messages: [],
|
||||
tools: [tool],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("start");
|
||||
const queuedUserMessage: AgentMessage = createUserMessage("interrupt");
|
||||
|
||||
let queuedDelivered = false;
|
||||
let callIndex = 0;
|
||||
let sawInterruptInContext = false;
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
toolExecution: "sequential",
|
||||
getSteeringMessages: async () => {
|
||||
// Return steering message after tool execution has started.
|
||||
if (executed.length >= 1 && !queuedDelivered) {
|
||||
queuedDelivered = true;
|
||||
return [queuedUserMessage];
|
||||
}
|
||||
return [];
|
||||
},
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, (_model, ctx, _options) => {
|
||||
// Check if interrupt message is in context on second call
|
||||
if (callIndex === 1) {
|
||||
sawInterruptInContext = ctx.messages.some(
|
||||
(m) => m.role === "user" && typeof m.content === "string" && m.content === "interrupt",
|
||||
);
|
||||
}
|
||||
|
||||
const mockStream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
if (callIndex === 0) {
|
||||
// First call: return two tool calls
|
||||
const message = createAssistantMessage(
|
||||
[
|
||||
{ type: "toolCall", id: "tool-1", name: "echo", arguments: { value: "first" } },
|
||||
{ type: "toolCall", id: "tool-2", name: "echo", arguments: { value: "second" } },
|
||||
],
|
||||
"toolUse",
|
||||
);
|
||||
mockStream.push({ type: "done", reason: "toolUse", message });
|
||||
} else {
|
||||
// Second call: return final response
|
||||
const message = createAssistantMessage([{ type: "text", text: "done" }]);
|
||||
mockStream.push({ type: "done", reason: "stop", message });
|
||||
}
|
||||
callIndex++;
|
||||
});
|
||||
return mockStream;
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// Both tools should execute before steering is injected
|
||||
expect(executed).toEqual(["first", "second"]);
|
||||
|
||||
const toolEnds = events.filter(
|
||||
(e): e is Extract<AgentEvent, { type: "tool_execution_end" }> => e.type === "tool_execution_end",
|
||||
);
|
||||
expect(toolEnds.length).toBe(2);
|
||||
expect(toolEnds[0].isError).toBe(false);
|
||||
expect(toolEnds[1].isError).toBe(false);
|
||||
|
||||
// Queued message should appear in events after both tool result messages
|
||||
const eventSequence = events.flatMap((event) => {
|
||||
if (event.type !== "message_start") return [];
|
||||
if (event.message.role === "toolResult") return [`tool:${event.message.toolCallId}`];
|
||||
if (event.message.role === "user" && typeof event.message.content === "string") {
|
||||
return [event.message.content];
|
||||
}
|
||||
return [];
|
||||
});
|
||||
expect(eventSequence).toContain("interrupt");
|
||||
expect(eventSequence.indexOf("tool:tool-1")).toBeLessThan(eventSequence.indexOf("interrupt"));
|
||||
expect(eventSequence.indexOf("tool:tool-2")).toBeLessThan(eventSequence.indexOf("interrupt"));
|
||||
|
||||
// Interrupt message should be in context when second LLM call is made
|
||||
expect(sawInterruptInContext).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("agentLoopContinue with AgentMessage", () => {
|
||||
it("should throw when context has no messages", () => {
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
expect(() => agentLoopContinue(context, config)).toThrow("Cannot continue: no messages in context");
|
||||
});
|
||||
|
||||
it("should continue from existing context without emitting user message events", async () => {
|
||||
const userMessage: AgentMessage = createUserMessage("Hello");
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [userMessage],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Response" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoopContinue(context, config, undefined, streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const messages = await stream.result();
|
||||
|
||||
// Should only return the new assistant message (not the existing user message)
|
||||
expect(messages.length).toBe(1);
|
||||
expect(messages[0].role).toBe("assistant");
|
||||
|
||||
// Should NOT have user message events (that's the key difference from agentLoop)
|
||||
const messageEndEvents = events.filter((e) => e.type === "message_end");
|
||||
expect(messageEndEvents.length).toBe(1);
|
||||
expect((messageEndEvents[0] as any).message.role).toBe("assistant");
|
||||
});
|
||||
|
||||
it("should allow custom message types as last message (caller responsibility)", async () => {
|
||||
// Custom message that will be converted to user message by convertToLlm
|
||||
interface CustomMessage {
|
||||
role: "custom";
|
||||
text: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
const customMessage: CustomMessage = {
|
||||
role: "custom",
|
||||
text: "Hook content",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "You are helpful.",
|
||||
messages: [customMessage as unknown as AgentMessage],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: (messages) => {
|
||||
// Convert custom to user message
|
||||
return messages
|
||||
.map((m) => {
|
||||
if ((m as any).role === "custom") {
|
||||
return {
|
||||
role: "user" as const,
|
||||
content: (m as any).text,
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
}
|
||||
return m;
|
||||
})
|
||||
.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult") as Message[];
|
||||
},
|
||||
};
|
||||
|
||||
const streamFn = () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage([{ type: "text", text: "Response to custom message" }]);
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
};
|
||||
|
||||
// Should not throw - the custom message will be converted to user message
|
||||
const stream = agentLoopContinue(context, config, undefined, streamFn);
|
||||
|
||||
const events: AgentEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const messages = await stream.result();
|
||||
expect(messages.length).toBe(1);
|
||||
expect(messages[0].role).toBe("assistant");
|
||||
});
|
||||
});
|
||||
340
packages/agent/test/agent.test.ts
Normal file
340
packages/agent/test/agent.test.ts
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
import { type AssistantMessage, type AssistantMessageEvent, EventStream, getModel } from "@jaeswift/jae-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent } from "../src/index.js";
|
||||
|
||||
// Mock stream that mimics AssistantMessageEventStream
|
||||
class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function createAssistantMessage(text: string): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text }],
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "mock",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
describe("Agent", () => {
|
||||
it("should create an agent instance with default state", () => {
|
||||
const agent = new Agent();
|
||||
|
||||
expect(agent.state).toBeDefined();
|
||||
expect(agent.state.systemPrompt).toBe("");
|
||||
expect(agent.state.model).toBeDefined();
|
||||
expect(agent.state.thinkingLevel).toBe("off");
|
||||
expect(agent.state.tools).toEqual([]);
|
||||
expect(agent.state.messages).toEqual([]);
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
expect(agent.state.streamMessage).toBe(null);
|
||||
expect(agent.state.pendingToolCalls).toEqual(new Set());
|
||||
expect(agent.state.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should create an agent instance with custom initial state", () => {
|
||||
const customModel = getModel("openai", "gpt-4o-mini");
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant.",
|
||||
model: customModel,
|
||||
thinkingLevel: "low",
|
||||
},
|
||||
});
|
||||
|
||||
expect(agent.state.systemPrompt).toBe("You are a helpful assistant.");
|
||||
expect(agent.state.model).toBe(customModel);
|
||||
expect(agent.state.thinkingLevel).toBe("low");
|
||||
});
|
||||
|
||||
it("should subscribe to events", () => {
|
||||
const agent = new Agent();
|
||||
|
||||
let eventCount = 0;
|
||||
const unsubscribe = agent.subscribe((_event) => {
|
||||
eventCount++;
|
||||
});
|
||||
|
||||
// No initial event on subscribe
|
||||
expect(eventCount).toBe(0);
|
||||
|
||||
// State mutators don't emit events
|
||||
agent.setSystemPrompt("Test prompt");
|
||||
expect(eventCount).toBe(0);
|
||||
expect(agent.state.systemPrompt).toBe("Test prompt");
|
||||
|
||||
// Unsubscribe should work
|
||||
unsubscribe();
|
||||
agent.setSystemPrompt("Another prompt");
|
||||
expect(eventCount).toBe(0); // Should not increase
|
||||
});
|
||||
|
||||
it("should update state with mutators", () => {
|
||||
const agent = new Agent();
|
||||
|
||||
// Test setSystemPrompt
|
||||
agent.setSystemPrompt("Custom prompt");
|
||||
expect(agent.state.systemPrompt).toBe("Custom prompt");
|
||||
|
||||
// Test setModel
|
||||
const newModel = getModel("google", "gemini-2.5-flash");
|
||||
agent.setModel(newModel);
|
||||
expect(agent.state.model).toBe(newModel);
|
||||
|
||||
// Test setThinkingLevel
|
||||
agent.setThinkingLevel("high");
|
||||
expect(agent.state.thinkingLevel).toBe("high");
|
||||
|
||||
// Test setTools
|
||||
const tools = [{ name: "test", description: "test tool" } as any];
|
||||
agent.setTools(tools);
|
||||
expect(agent.state.tools).toBe(tools);
|
||||
|
||||
// Test replaceMessages
|
||||
const messages = [{ role: "user" as const, content: "Hello", timestamp: Date.now() }];
|
||||
agent.replaceMessages(messages);
|
||||
expect(agent.state.messages).toEqual(messages);
|
||||
expect(agent.state.messages).not.toBe(messages); // Should be a copy
|
||||
|
||||
// Test appendMessage
|
||||
const newMessage = { role: "assistant" as const, content: [{ type: "text" as const, text: "Hi" }] };
|
||||
agent.appendMessage(newMessage as any);
|
||||
expect(agent.state.messages).toHaveLength(2);
|
||||
expect(agent.state.messages[1]).toBe(newMessage);
|
||||
|
||||
// Test clearMessages
|
||||
agent.clearMessages();
|
||||
expect(agent.state.messages).toEqual([]);
|
||||
});
|
||||
|
||||
it("should support steering message queue", async () => {
|
||||
const agent = new Agent();
|
||||
|
||||
const message = { role: "user" as const, content: "Steering message", timestamp: Date.now() };
|
||||
agent.steer(message);
|
||||
|
||||
// The message is queued but not yet in state.messages
|
||||
expect(agent.state.messages).not.toContainEqual(message);
|
||||
});
|
||||
|
||||
it("should support follow-up message queue", async () => {
|
||||
const agent = new Agent();
|
||||
|
||||
const message = { role: "user" as const, content: "Follow-up message", timestamp: Date.now() };
|
||||
agent.followUp(message);
|
||||
|
||||
// The message is queued but not yet in state.messages
|
||||
expect(agent.state.messages).not.toContainEqual(message);
|
||||
});
|
||||
|
||||
it("should handle abort controller", () => {
|
||||
const agent = new Agent();
|
||||
|
||||
// Should not throw even if nothing is running
|
||||
expect(() => agent.abort()).not.toThrow();
|
||||
});
|
||||
|
||||
it("should throw when prompt() called while streaming", async () => {
|
||||
let abortSignal: AbortSignal | undefined;
|
||||
const agent = new Agent({
|
||||
// Use a stream function that responds to abort
|
||||
streamFn: (_model, _context, options) => {
|
||||
abortSignal = options?.signal;
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
stream.push({ type: "start", partial: createAssistantMessage("") });
|
||||
// Check abort signal periodically
|
||||
const checkAbort = () => {
|
||||
if (abortSignal?.aborted) {
|
||||
stream.push({ type: "error", reason: "aborted", error: createAssistantMessage("Aborted") });
|
||||
} else {
|
||||
setTimeout(checkAbort, 5);
|
||||
}
|
||||
};
|
||||
checkAbort();
|
||||
});
|
||||
return stream;
|
||||
},
|
||||
});
|
||||
|
||||
// Start first prompt (don't await, it will block until abort)
|
||||
const firstPrompt = agent.prompt("First message");
|
||||
|
||||
// Wait a tick for isStreaming to be set
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
expect(agent.state.isStreaming).toBe(true);
|
||||
|
||||
// Second prompt should reject
|
||||
await expect(agent.prompt("Second message")).rejects.toThrow(
|
||||
"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
|
||||
);
|
||||
|
||||
// Cleanup - abort to stop the stream
|
||||
agent.abort();
|
||||
await firstPrompt.catch(() => {}); // Ignore abort error
|
||||
});
|
||||
|
||||
it("should throw when continue() called while streaming", async () => {
|
||||
let abortSignal: AbortSignal | undefined;
|
||||
const agent = new Agent({
|
||||
streamFn: (_model, _context, options) => {
|
||||
abortSignal = options?.signal;
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
stream.push({ type: "start", partial: createAssistantMessage("") });
|
||||
const checkAbort = () => {
|
||||
if (abortSignal?.aborted) {
|
||||
stream.push({ type: "error", reason: "aborted", error: createAssistantMessage("Aborted") });
|
||||
} else {
|
||||
setTimeout(checkAbort, 5);
|
||||
}
|
||||
};
|
||||
checkAbort();
|
||||
});
|
||||
return stream;
|
||||
},
|
||||
});
|
||||
|
||||
// Start first prompt
|
||||
const firstPrompt = agent.prompt("First message");
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
expect(agent.state.isStreaming).toBe(true);
|
||||
|
||||
// continue() should reject
|
||||
await expect(agent.continue()).rejects.toThrow(
|
||||
"Agent is already processing. Wait for completion before continuing.",
|
||||
);
|
||||
|
||||
// Cleanup
|
||||
agent.abort();
|
||||
await firstPrompt.catch(() => {});
|
||||
});
|
||||
|
||||
it("continue() should process queued follow-up messages after an assistant turn", async () => {
|
||||
const agent = new Agent({
|
||||
streamFn: () => {
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
stream.push({ type: "done", reason: "stop", message: createAssistantMessage("Processed") });
|
||||
});
|
||||
return stream;
|
||||
},
|
||||
});
|
||||
|
||||
agent.replaceMessages([
|
||||
{
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Initial" }],
|
||||
timestamp: Date.now() - 10,
|
||||
},
|
||||
createAssistantMessage("Initial response"),
|
||||
]);
|
||||
|
||||
agent.followUp({
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Queued follow-up" }],
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
await expect(agent.continue()).resolves.toBeUndefined();
|
||||
|
||||
const hasQueuedFollowUp = agent.state.messages.some((message) => {
|
||||
if (message.role !== "user") return false;
|
||||
if (typeof message.content === "string") return message.content === "Queued follow-up";
|
||||
return message.content.some((part) => part.type === "text" && part.text === "Queued follow-up");
|
||||
});
|
||||
|
||||
expect(hasQueuedFollowUp).toBe(true);
|
||||
expect(agent.state.messages[agent.state.messages.length - 1].role).toBe("assistant");
|
||||
});
|
||||
|
||||
it("continue() should keep one-at-a-time steering semantics from assistant tail", async () => {
|
||||
let responseCount = 0;
|
||||
const agent = new Agent({
|
||||
streamFn: () => {
|
||||
const stream = new MockAssistantStream();
|
||||
responseCount++;
|
||||
queueMicrotask(() => {
|
||||
stream.push({
|
||||
type: "done",
|
||||
reason: "stop",
|
||||
message: createAssistantMessage(`Processed ${responseCount}`),
|
||||
});
|
||||
});
|
||||
return stream;
|
||||
},
|
||||
});
|
||||
|
||||
agent.replaceMessages([
|
||||
{
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Initial" }],
|
||||
timestamp: Date.now() - 10,
|
||||
},
|
||||
createAssistantMessage("Initial response"),
|
||||
]);
|
||||
|
||||
agent.steer({
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Steering 1" }],
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
agent.steer({
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Steering 2" }],
|
||||
timestamp: Date.now() + 1,
|
||||
});
|
||||
|
||||
await expect(agent.continue()).resolves.toBeUndefined();
|
||||
|
||||
const recentMessages = agent.state.messages.slice(-4);
|
||||
expect(recentMessages.map((m) => m.role)).toEqual(["user", "assistant", "user", "assistant"]);
|
||||
expect(responseCount).toBe(2);
|
||||
});
|
||||
|
||||
it("forwards sessionId to streamFn options", async () => {
|
||||
let receivedSessionId: string | undefined;
|
||||
const agent = new Agent({
|
||||
sessionId: "session-abc",
|
||||
streamFn: (_model, _context, options) => {
|
||||
receivedSessionId = options?.sessionId;
|
||||
const stream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
const message = createAssistantMessage("ok");
|
||||
stream.push({ type: "done", reason: "stop", message });
|
||||
});
|
||||
return stream;
|
||||
},
|
||||
});
|
||||
|
||||
await agent.prompt("hello");
|
||||
expect(receivedSessionId).toBe("session-abc");
|
||||
|
||||
// Test setter
|
||||
agent.sessionId = "session-def";
|
||||
expect(agent.sessionId).toBe("session-def");
|
||||
|
||||
await agent.prompt("hello again");
|
||||
expect(receivedSessionId).toBe("session-def");
|
||||
});
|
||||
});
|
||||
287
packages/agent/test/bedrock-models.test.ts
Normal file
287
packages/agent/test/bedrock-models.test.ts
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
/**
|
||||
* A test suite to ensure Amazon Bedrock models work correctly with the agent loop.
|
||||
*
|
||||
* Some Bedrock models don't support all features (e.g., reasoning signatures).
|
||||
* This test suite verifies that the agent loop works with various Bedrock models.
|
||||
*
|
||||
* This test suite is not enabled by default unless AWS credentials and
|
||||
* `BEDROCK_EXTENSIVE_MODEL_TEST` environment variables are set.
|
||||
*
|
||||
* You can run this test suite with:
|
||||
* ```bash
|
||||
* $ AWS_REGION=us-east-1 BEDROCK_EXTENSIVE_MODEL_TEST=1 AWS_PROFILE=pi npm test -- ./test/bedrock-models.test.ts
|
||||
* ```
|
||||
*
|
||||
* ## Known Issues by Category
|
||||
*
|
||||
* 1. **Inference Profile Required**: Some models require an inference profile ARN instead of on-demand.
|
||||
* 2. **Invalid Model ID**: Model identifiers that don't exist in the current region.
|
||||
* 3. **Max Tokens Exceeded**: Model's maxTokens in our config exceeds the actual limit.
|
||||
* 4. **No Reasoning in User Messages**: Model rejects reasoning content when replayed in conversation.
|
||||
* 5. **Invalid Signature Format**: Model validates signature format (Anthropic newer models).
|
||||
*/
|
||||
|
||||
import type { AssistantMessage } from "@jaeswift/jae-ai";
|
||||
import { getModels } from "@jaeswift/jae-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent } from "../src/index.js";
|
||||
import { hasBedrockCredentials } from "./bedrock-utils.js";
|
||||
|
||||
// =============================================================================
|
||||
// Known Issue Categories
|
||||
// =============================================================================
|
||||
|
||||
/** Models that require inference profile ARN (not available on-demand in us-east-1) */
|
||||
const REQUIRES_INFERENCE_PROFILE = new Set([
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"anthropic.claude-3-opus-20240229-v1:0",
|
||||
"meta.llama3-1-70b-instruct-v1:0",
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
]);
|
||||
|
||||
/** Models with invalid identifiers (not available in us-east-1 or don't exist) */
|
||||
const INVALID_MODEL_ID = new Set([
|
||||
"deepseek.v3-v1:0",
|
||||
"eu.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"eu.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"qwen.qwen3-235b-a22b-2507-v1:0",
|
||||
"qwen.qwen3-coder-480b-a35b-v1:0",
|
||||
]);
|
||||
|
||||
/** Models where our maxTokens config exceeds the model's actual limit */
|
||||
const MAX_TOKENS_EXCEEDED = new Set([
|
||||
"us.meta.llama4-maverick-17b-instruct-v1:0",
|
||||
"us.meta.llama4-scout-17b-instruct-v1:0",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Models that reject reasoning content in user messages (when replaying conversation).
|
||||
* These work for multi-turn but fail when synthetic thinking is injected.
|
||||
*/
|
||||
const NO_REASONING_IN_USER_MESSAGES = new Set([
|
||||
// Mistral models
|
||||
"mistral.ministral-3-14b-instruct",
|
||||
"mistral.ministral-3-8b-instruct",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
"mistral.voxtral-mini-3b-2507",
|
||||
"mistral.voxtral-small-24b-2507",
|
||||
// Nvidia models
|
||||
"nvidia.nemotron-nano-12b-v2",
|
||||
"nvidia.nemotron-nano-9b-v2",
|
||||
// Qwen models
|
||||
"qwen.qwen3-coder-30b-a3b-v1:0",
|
||||
// Amazon Nova models
|
||||
"us.amazon.nova-lite-v1:0",
|
||||
"us.amazon.nova-micro-v1:0",
|
||||
"us.amazon.nova-premier-v1:0",
|
||||
"us.amazon.nova-pro-v1:0",
|
||||
// Meta Llama models
|
||||
"us.meta.llama3-2-11b-instruct-v1:0",
|
||||
"us.meta.llama3-2-1b-instruct-v1:0",
|
||||
"us.meta.llama3-2-3b-instruct-v1:0",
|
||||
"us.meta.llama3-2-90b-instruct-v1:0",
|
||||
"us.meta.llama3-3-70b-instruct-v1:0",
|
||||
// DeepSeek
|
||||
"us.deepseek.r1-v1:0",
|
||||
// Older Anthropic models
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
// Cohere models
|
||||
"cohere.command-r-plus-v1:0",
|
||||
"cohere.command-r-v1:0",
|
||||
// Google models
|
||||
"google.gemma-3-27b-it",
|
||||
"google.gemma-3-4b-it",
|
||||
// Non-Anthropic models that don't support signatures (now handled by omitting signature)
|
||||
// but still reject reasoning content in user messages
|
||||
"global.amazon.nova-2-lite-v1:0",
|
||||
"minimax.minimax-m2",
|
||||
"moonshot.kimi-k2-thinking",
|
||||
"openai.gpt-oss-120b-1:0",
|
||||
"openai.gpt-oss-20b-1:0",
|
||||
"openai.gpt-oss-safeguard-120b",
|
||||
"openai.gpt-oss-safeguard-20b",
|
||||
"qwen.qwen3-32b-v1:0",
|
||||
"qwen.qwen3-next-80b-a3b",
|
||||
"qwen.qwen3-vl-235b-a22b",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Models that validate signature format (Anthropic newer models).
|
||||
* These work for multi-turn but fail when synthetic/invalid signature is injected.
|
||||
*/
|
||||
const VALIDATES_SIGNATURE_FORMAT = new Set([
|
||||
"global.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"global.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"global.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"global.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
"us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"us.anthropic.claude-opus-4-20250514-v1:0",
|
||||
]);
|
||||
|
||||
/**
|
||||
* DeepSeek R1 fails multi-turn because it rejects reasoning in the replayed assistant message.
|
||||
*/
|
||||
const REJECTS_REASONING_ON_REPLAY = new Set(["us.deepseek.r1-v1:0"]);
|
||||
|
||||
// =============================================================================
|
||||
// Helper Functions
|
||||
// =============================================================================
|
||||
|
||||
function isModelUnavailable(modelId: string): boolean {
|
||||
return REQUIRES_INFERENCE_PROFILE.has(modelId) || INVALID_MODEL_ID.has(modelId) || MAX_TOKENS_EXCEEDED.has(modelId);
|
||||
}
|
||||
|
||||
function failsMultiTurnWithThinking(modelId: string): boolean {
|
||||
return REJECTS_REASONING_ON_REPLAY.has(modelId);
|
||||
}
|
||||
|
||||
function failsSyntheticSignature(modelId: string): boolean {
|
||||
return NO_REASONING_IN_USER_MESSAGES.has(modelId) || VALIDATES_SIGNATURE_FORMAT.has(modelId);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tests
|
||||
// =============================================================================
|
||||
|
||||
describe("Amazon Bedrock Models - Agent Loop", () => {
|
||||
const shouldRunExtensiveTests = hasBedrockCredentials() && process.env.BEDROCK_EXTENSIVE_MODEL_TEST;
|
||||
|
||||
// Get all Amazon Bedrock models
|
||||
const allBedrockModels = getModels("amazon-bedrock");
|
||||
|
||||
if (shouldRunExtensiveTests) {
|
||||
for (const model of allBedrockModels) {
|
||||
const modelId = model.id;
|
||||
|
||||
describe(`Model: ${modelId}`, () => {
|
||||
// Skip entirely unavailable models
|
||||
const unavailable = isModelUnavailable(modelId);
|
||||
|
||||
it.skipIf(unavailable)("should handle basic text prompt", { timeout: 60_000 }, async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant. Be extremely concise.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
});
|
||||
|
||||
await agent.prompt("Reply with exactly: 'OK'");
|
||||
|
||||
if (agent.state.error) {
|
||||
throw new Error(`Basic prompt error: ${agent.state.error}`);
|
||||
}
|
||||
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
expect(agent.state.messages.length).toBe(2);
|
||||
|
||||
const assistantMessage = agent.state.messages[1];
|
||||
if (assistantMessage.role !== "assistant") throw new Error("Expected assistant message");
|
||||
|
||||
console.log(`${modelId}: OK`);
|
||||
});
|
||||
|
||||
// Skip if model is unavailable or known to fail multi-turn with thinking
|
||||
const skipMultiTurn = unavailable || failsMultiTurnWithThinking(modelId);
|
||||
|
||||
it.skipIf(skipMultiTurn)(
|
||||
"should handle multi-turn conversation with thinking content in history",
|
||||
{ timeout: 120_000 },
|
||||
async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant. Be extremely concise.",
|
||||
model,
|
||||
thinkingLevel: "medium",
|
||||
tools: [],
|
||||
},
|
||||
});
|
||||
|
||||
// First turn
|
||||
await agent.prompt("My name is Alice.");
|
||||
|
||||
if (agent.state.error) {
|
||||
throw new Error(`First turn error: ${agent.state.error}`);
|
||||
}
|
||||
|
||||
// Second turn - this should replay the first assistant message which may contain thinking
|
||||
await agent.prompt("What is my name?");
|
||||
|
||||
if (agent.state.error) {
|
||||
throw new Error(`Second turn error: ${agent.state.error}`);
|
||||
}
|
||||
|
||||
expect(agent.state.messages.length).toBe(4);
|
||||
console.log(`${modelId}: multi-turn OK`);
|
||||
},
|
||||
);
|
||||
|
||||
// Skip if model is unavailable or known to fail synthetic signature
|
||||
const skipSynthetic = unavailable || failsSyntheticSignature(modelId);
|
||||
|
||||
it.skipIf(skipSynthetic)(
|
||||
"should handle conversation with synthetic thinking signature in history",
|
||||
{ timeout: 60_000 },
|
||||
async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant. Be extremely concise.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
});
|
||||
|
||||
// Inject a message with a thinking block that has a signature
|
||||
const syntheticAssistantMessage: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "thinking",
|
||||
thinking: "I need to remember the user's name.",
|
||||
thinkingSignature: "synthetic-signature-123",
|
||||
},
|
||||
{ type: "text", text: "Nice to meet you, Alice!" },
|
||||
],
|
||||
api: "bedrock-converse-stream",
|
||||
provider: "amazon-bedrock",
|
||||
model: modelId,
|
||||
usage: {
|
||||
input: 10,
|
||||
output: 20,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 30,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
agent.replaceMessages([
|
||||
{ role: "user", content: "My name is Alice.", timestamp: Date.now() },
|
||||
syntheticAssistantMessage,
|
||||
]);
|
||||
|
||||
await agent.prompt("What is my name?");
|
||||
|
||||
if (agent.state.error) {
|
||||
throw new Error(`Synthetic signature error: ${agent.state.error}`);
|
||||
}
|
||||
|
||||
expect(agent.state.messages.length).toBe(4);
|
||||
console.log(`${modelId}: synthetic signature OK`);
|
||||
},
|
||||
);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
it.skip("skipped - set AWS credentials and BEDROCK_EXTENSIVE_MODEL_TEST=1 to run", () => {});
|
||||
}
|
||||
});
|
||||
18
packages/agent/test/bedrock-utils.ts
Normal file
18
packages/agent/test/bedrock-utils.ts
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
/**
|
||||
* Utility functions for Amazon Bedrock tests
|
||||
*/
|
||||
|
||||
/**
|
||||
* Check if any valid AWS credentials are configured for Bedrock.
|
||||
* Returns true if any of the following are set:
|
||||
* - AWS_PROFILE (named profile from ~/.aws/credentials)
|
||||
* - AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY (IAM keys)
|
||||
* - AWS_BEARER_TOKEN_BEDROCK (Bedrock API key)
|
||||
*/
|
||||
export function hasBedrockCredentials(): boolean {
|
||||
return !!(
|
||||
process.env.AWS_PROFILE ||
|
||||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK
|
||||
);
|
||||
}
|
||||
511
packages/agent/test/e2e.test.ts
Normal file
511
packages/agent/test/e2e.test.ts
Normal file
|
|
@ -0,0 +1,511 @@
|
|||
import type { AssistantMessage, Model, ToolResultMessage, UserMessage } from "@jaeswift/jae-ai";
|
||||
import { getModel } from "@jaeswift/jae-ai";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { Agent } from "../src/index.js";
|
||||
import { hasBedrockCredentials } from "./bedrock-utils.js";
|
||||
import { calculateTool } from "./utils/calculate.js";
|
||||
|
||||
delete process.env.ANTHROPIC_OAUTH_TOKEN;
|
||||
|
||||
async function basicPrompt(model: Model<any>) {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant. Keep your responses concise.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
});
|
||||
|
||||
await agent.prompt("What is 2+2? Answer with just the number.");
|
||||
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
expect(agent.state.messages.length).toBe(2);
|
||||
expect(agent.state.messages[0].role).toBe("user");
|
||||
expect(agent.state.messages[1].role).toBe("assistant");
|
||||
|
||||
const assistantMessage = agent.state.messages[1];
|
||||
if (assistantMessage.role !== "assistant") throw new Error("Expected assistant message");
|
||||
expect(assistantMessage.content.length).toBeGreaterThan(0);
|
||||
|
||||
const textContent = assistantMessage.content.find((c) => c.type === "text");
|
||||
expect(textContent).toBeDefined();
|
||||
if (textContent?.type !== "text") throw new Error("Expected text content");
|
||||
expect(textContent.text).toContain("4");
|
||||
}
|
||||
|
||||
async function toolExecution(model: Model<any>) {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant. Always use the calculator tool for math.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
});
|
||||
|
||||
await agent.prompt("Calculate 123 * 456 using the calculator tool.");
|
||||
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
expect(agent.state.messages.length).toBeGreaterThanOrEqual(3);
|
||||
|
||||
const toolResultMsg = agent.state.messages.find((m) => m.role === "toolResult");
|
||||
expect(toolResultMsg).toBeDefined();
|
||||
if (toolResultMsg?.role !== "toolResult") throw new Error("Expected tool result message");
|
||||
const textContent =
|
||||
toolResultMsg.content
|
||||
?.filter((c) => c.type === "text")
|
||||
.map((c: any) => c.text)
|
||||
.join("\n") || "";
|
||||
expect(textContent).toBeDefined();
|
||||
|
||||
const expectedResult = 123 * 456;
|
||||
expect(textContent).toContain(String(expectedResult));
|
||||
|
||||
const finalMessage = agent.state.messages[agent.state.messages.length - 1];
|
||||
if (finalMessage.role !== "assistant") throw new Error("Expected final assistant message");
|
||||
const finalText = finalMessage.content.find((c) => c.type === "text");
|
||||
expect(finalText).toBeDefined();
|
||||
if (finalText?.type !== "text") throw new Error("Expected text content");
|
||||
// Check for number with or without comma formatting
|
||||
const hasNumber =
|
||||
finalText.text.includes(String(expectedResult)) ||
|
||||
finalText.text.includes("56,088") ||
|
||||
finalText.text.includes("56088");
|
||||
expect(hasNumber).toBe(true);
|
||||
}
|
||||
|
||||
async function abortExecution(model: Model<any>) {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
});
|
||||
|
||||
const promptPromise = agent.prompt("Calculate 100 * 200, then 300 * 400, then sum the results.");
|
||||
|
||||
setTimeout(() => {
|
||||
agent.abort();
|
||||
}, 100);
|
||||
|
||||
await promptPromise;
|
||||
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
expect(agent.state.messages.length).toBeGreaterThanOrEqual(2);
|
||||
|
||||
const lastMessage = agent.state.messages[agent.state.messages.length - 1];
|
||||
if (lastMessage.role !== "assistant") throw new Error("Expected assistant message");
|
||||
expect(lastMessage.stopReason).toBe("aborted");
|
||||
expect(lastMessage.errorMessage).toBeDefined();
|
||||
expect(agent.state.error).toBeDefined();
|
||||
expect(agent.state.error).toBe(lastMessage.errorMessage);
|
||||
}
|
||||
|
||||
async function stateUpdates(model: Model<any>) {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
});
|
||||
|
||||
const events: Array<string> = [];
|
||||
|
||||
agent.subscribe((event) => {
|
||||
events.push(event.type);
|
||||
});
|
||||
|
||||
await agent.prompt("Count from 1 to 5.");
|
||||
|
||||
// Should have received lifecycle events
|
||||
expect(events).toContain("agent_start");
|
||||
expect(events).toContain("agent_end");
|
||||
expect(events).toContain("message_start");
|
||||
expect(events).toContain("message_end");
|
||||
// May have message_update events during streaming
|
||||
const hasMessageUpdates = events.some((e) => e === "message_update");
|
||||
expect(hasMessageUpdates).toBe(true);
|
||||
|
||||
// Check final state
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
expect(agent.state.messages.length).toBe(2); // User message + assistant response
|
||||
}
|
||||
|
||||
async function multiTurnConversation(model: Model<any>) {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
});
|
||||
|
||||
await agent.prompt("My name is Alice.");
|
||||
expect(agent.state.messages.length).toBe(2);
|
||||
|
||||
await agent.prompt("What is my name?");
|
||||
expect(agent.state.messages.length).toBe(4);
|
||||
|
||||
const lastMessage = agent.state.messages[3];
|
||||
if (lastMessage.role !== "assistant") throw new Error("Expected assistant message");
|
||||
const lastText = lastMessage.content.find((c) => c.type === "text");
|
||||
if (lastText?.type !== "text") throw new Error("Expected text content");
|
||||
expect(lastText.text.toLowerCase()).toContain("alice");
|
||||
}
|
||||
|
||||
describe("Agent E2E Tests", () => {
|
||||
describe.skipIf(!process.env.GEMINI_API_KEY)("Google Provider (gemini-2.5-flash)", () => {
|
||||
const model = getModel("google", "gemini-2.5-flash");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("OpenAI Provider (gpt-4o-mini)", () => {
|
||||
const model = getModel("openai", "gpt-4o-mini");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.ANTHROPIC_API_KEY)("Anthropic Provider (claude-haiku-4-5)", () => {
|
||||
const model = getModel("anthropic", "claude-haiku-4-5");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.XAI_API_KEY)("xAI Provider (grok-3)", () => {
|
||||
const model = getModel("xai", "grok-3");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.GROQ_API_KEY)("Groq Provider (openai/gpt-oss-20b)", () => {
|
||||
const model = getModel("groq", "openai/gpt-oss-20b");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});
|
||||
|
||||
/*describe.skipIf(!process.env.CEREBRAS_API_KEY)("Cerebras Provider (gpt-oss-120b)", () => {
|
||||
const model = getModel("cerebras", "gpt-oss-120b");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});*/
|
||||
|
||||
describe.skipIf(!process.env.ZAI_API_KEY)("zAI Provider (glm-4.5-air)", () => {
|
||||
const model = getModel("zai", "glm-4.5-air");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!hasBedrockCredentials())("Amazon Bedrock Provider (claude-sonnet-4-5)", () => {
|
||||
const model = getModel("amazon-bedrock", "global.anthropic.claude-sonnet-4-5-20250929-v1:0");
|
||||
|
||||
it("should handle basic text prompt", async () => {
|
||||
await basicPrompt(model);
|
||||
});
|
||||
|
||||
it("should execute tools correctly", async () => {
|
||||
await toolExecution(model);
|
||||
});
|
||||
|
||||
it("should handle abort during execution", async () => {
|
||||
await abortExecution(model);
|
||||
});
|
||||
|
||||
it("should emit state updates during streaming", async () => {
|
||||
await stateUpdates(model);
|
||||
});
|
||||
|
||||
it("should maintain context across multiple turns", async () => {
|
||||
await multiTurnConversation(model);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Agent.continue()", () => {
|
||||
describe("validation", () => {
|
||||
it("should throw when no messages in context", async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "Test",
|
||||
model: getModel("openai", "gpt-5.4"),
|
||||
},
|
||||
});
|
||||
|
||||
await expect(agent.continue()).rejects.toThrow("No messages to continue from");
|
||||
});
|
||||
|
||||
it("should throw when last message is assistant", async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "Test",
|
||||
model: getModel("openai", "gpt-5.4"),
|
||||
},
|
||||
});
|
||||
|
||||
const assistantMessage: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "Hello" }],
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "gpt-5.4",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
agent.replaceMessages([assistantMessage]);
|
||||
|
||||
await expect(agent.continue()).rejects.toThrow("Cannot continue from message role: assistant");
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("continue from user message", () => {
|
||||
const model = getModel("openai", "gpt-5.4");
|
||||
|
||||
it("should continue and get response when last message is user", async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "You are a helpful assistant. Follow instructions exactly.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
},
|
||||
});
|
||||
|
||||
// Manually add a user message without calling prompt()
|
||||
const userMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Say exactly: HELLO WORLD" }],
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
agent.replaceMessages([userMessage]);
|
||||
|
||||
// Continue from the user message
|
||||
await agent.continue();
|
||||
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
expect(agent.state.messages.length).toBe(2);
|
||||
expect(agent.state.messages[0].role).toBe("user");
|
||||
expect(agent.state.messages[1].role).toBe("assistant");
|
||||
|
||||
const assistantMsg = agent.state.messages[1] as AssistantMessage;
|
||||
const textContent = assistantMsg.content.find((c) => c.type === "text");
|
||||
expect(textContent).toBeDefined();
|
||||
if (textContent?.type === "text") {
|
||||
expect(textContent.text.toUpperCase()).toContain("HELLO WORLD");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe.skipIf(!process.env.OPENAI_API_KEY)("continue from tool result", () => {
|
||||
const model = getModel("openai", "gpt-5.4");
|
||||
|
||||
it("should continue and process tool results", async () => {
|
||||
const agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt:
|
||||
"You are a helpful assistant. After getting a calculation result, state the answer clearly.",
|
||||
model,
|
||||
thinkingLevel: "off",
|
||||
tools: [calculateTool],
|
||||
},
|
||||
});
|
||||
|
||||
// Set up a conversation state as if tool was just executed
|
||||
const userMessage: UserMessage = {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "What is 5 + 3?" }],
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const assistantMessage: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{ type: "text", text: "Let me calculate that." },
|
||||
{ type: "toolCall", id: "calc-1", name: "calculate", arguments: { expression: "5 + 3" } },
|
||||
],
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
model: "claude-haiku-4-5",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "toolUse",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const toolResult: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: "calc-1",
|
||||
toolName: "calculate",
|
||||
content: [{ type: "text", text: "5 + 3 = 8" }],
|
||||
isError: false,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
agent.replaceMessages([userMessage, assistantMessage, toolResult]);
|
||||
|
||||
// Continue from the tool result
|
||||
await agent.continue();
|
||||
|
||||
expect(agent.state.isStreaming).toBe(false);
|
||||
// Should have added an assistant response
|
||||
expect(agent.state.messages.length).toBeGreaterThanOrEqual(4);
|
||||
|
||||
const lastMessage = agent.state.messages[agent.state.messages.length - 1];
|
||||
expect(lastMessage.role).toBe("assistant");
|
||||
|
||||
if (lastMessage.role === "assistant") {
|
||||
const textContent = lastMessage.content
|
||||
.filter((c) => c.type === "text")
|
||||
.map((c) => (c as { type: "text"; text: string }).text)
|
||||
.join(" ");
|
||||
// Should mention 8 in the response
|
||||
expect(textContent).toMatch(/8/);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
32
packages/agent/test/utils/calculate.ts
Normal file
32
packages/agent/test/utils/calculate.ts
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import type { AgentTool, AgentToolResult } from "../../src/types.js";
|
||||
|
||||
export interface CalculateResult extends AgentToolResult<undefined> {
|
||||
content: Array<{ type: "text"; text: string }>;
|
||||
details: undefined;
|
||||
}
|
||||
|
||||
export function calculate(expression: string): CalculateResult {
|
||||
try {
|
||||
const result = new Function(`return ${expression}`)();
|
||||
return { content: [{ type: "text", text: `${expression} = ${result}` }], details: undefined };
|
||||
} catch (e: any) {
|
||||
throw new Error(e.message || String(e));
|
||||
}
|
||||
}
|
||||
|
||||
const calculateSchema = Type.Object({
|
||||
expression: Type.String({ description: "The mathematical expression to evaluate" }),
|
||||
});
|
||||
|
||||
type CalculateParams = Static<typeof calculateSchema>;
|
||||
|
||||
export const calculateTool: AgentTool<typeof calculateSchema, undefined> = {
|
||||
label: "Calculator",
|
||||
name: "calculate",
|
||||
description: "Evaluate mathematical expressions",
|
||||
parameters: calculateSchema,
|
||||
execute: async (_toolCallId: string, args: CalculateParams) => {
|
||||
return calculate(args.expression);
|
||||
},
|
||||
};
|
||||
46
packages/agent/test/utils/get-current-time.ts
Normal file
46
packages/agent/test/utils/get-current-time.ts
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import type { AgentTool, AgentToolResult } from "../../src/types.js";
|
||||
|
||||
export interface GetCurrentTimeResult extends AgentToolResult<{ utcTimestamp: number }> {}
|
||||
|
||||
export async function getCurrentTime(timezone?: string): Promise<GetCurrentTimeResult> {
|
||||
const date = new Date();
|
||||
if (timezone) {
|
||||
try {
|
||||
const timeStr = date.toLocaleString("en-US", {
|
||||
timeZone: timezone,
|
||||
dateStyle: "full",
|
||||
timeStyle: "long",
|
||||
});
|
||||
return {
|
||||
content: [{ type: "text", text: timeStr }],
|
||||
details: { utcTimestamp: date.getTime() },
|
||||
};
|
||||
} catch (_e) {
|
||||
throw new Error(`Invalid timezone: ${timezone}. Current UTC time: ${date.toISOString()}`);
|
||||
}
|
||||
}
|
||||
const timeStr = date.toLocaleString("en-US", { dateStyle: "full", timeStyle: "long" });
|
||||
return {
|
||||
content: [{ type: "text", text: timeStr }],
|
||||
details: { utcTimestamp: date.getTime() },
|
||||
};
|
||||
}
|
||||
|
||||
const getCurrentTimeSchema = Type.Object({
|
||||
timezone: Type.Optional(
|
||||
Type.String({ description: "Optional timezone (e.g., 'America/New_York', 'Europe/London')" }),
|
||||
),
|
||||
});
|
||||
|
||||
type GetCurrentTimeParams = Static<typeof getCurrentTimeSchema>;
|
||||
|
||||
export const getCurrentTimeTool: AgentTool<typeof getCurrentTimeSchema, { utcTimestamp: number }> = {
|
||||
label: "Current Time",
|
||||
name: "get_current_time",
|
||||
description: "Get the current date and time",
|
||||
parameters: getCurrentTimeSchema,
|
||||
execute: async (_toolCallId: string, args: GetCurrentTimeParams) => {
|
||||
return getCurrentTime(args.timezone);
|
||||
},
|
||||
};
|
||||
9
packages/agent/tsconfig.build.json
Normal file
9
packages/agent/tsconfig.build.json
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"extends": "../../tsconfig.base.json",
|
||||
"compilerOptions": {
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
|
||||
}
|
||||
9
packages/agent/vitest.config.ts
Normal file
9
packages/agent/vitest.config.ts
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import { defineConfig } from "vitest/config";
|
||||
|
||||
export default defineConfig({
|
||||
test: {
|
||||
globals: true,
|
||||
environment: "node",
|
||||
testTimeout: 30000, // 30 seconds for API calls
|
||||
},
|
||||
});
|
||||
909
packages/ai/CHANGELOG.md
Normal file
909
packages/ai/CHANGELOG.md
Normal file
|
|
@ -0,0 +1,909 @@
|
|||
# Changelog
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.62.0] - 2026-03-23
|
||||
|
||||
### Added
|
||||
|
||||
- Added `requestMetadata` option to `BedrockOptions` for AWS cost allocation tagging; key-value pairs are forwarded to the Bedrock Converse API `requestMetadata` field and appear in AWS Cost Explorer split cost allocation data ([#2511](https://github.com/jae/agent-jae-cli/pull/2511) by [@wjonaskr](https://github.com/wjonaskr))
|
||||
- Exported `BedrockOptions` type from the package root entry point, consistent with other provider option types.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI Responses replay for foreign tool-call item IDs by hashing foreign `function_call.id` values into bounded `fc_<hash>` IDs instead of preserving backend-specific normalized shapes that OpenAI Codex rejects.
|
||||
- Fixed Anthropic thinking disable handling to send `thinking: { type: "disabled" }` for reasoning-capable models when thinking is explicitly off, and added payload and env-gated end-to-end coverage for the Anthropic provider ([#2022](https://github.com/jae/agent-jae-cli/issues/2022))
|
||||
- Fixed explicit thinking disable handling across Google, Google Vertex, Gemini CLI, OpenAI Responses, Azure OpenAI Responses, and OpenRouter-backed OpenAI-compatible completions. Gemini 3 models now fall back to the lowest supported thinking level when full disable is not supported, and OpenAI/OpenRouter reasoning models now send explicit `none` effort instead of relying on provider defaults ([#2490](https://github.com/jae/agent-jae-cli/issues/2490))
|
||||
- Fixed OpenAI-compatible completions streams to ignore null chunks instead of crashing ([#2466](https://github.com/jae/agent-jae-cli/pull/2466) by [@Cheng-Zi-Qing](https://github.com/Cheng-Zi-Qing))
|
||||
|
||||
## [0.61.1] - 2026-03-20
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed MiniMax model metadata to add missing `MiniMax-M2.1-highspeed` entries for the `minimax` and `minimax-cn` providers and normalize MiniMax Anthropic-compatible context limits to the provider's supported model set ([#2445](https://github.com/jae/agent-jae-cli/pull/2445) by [@1500256797](https://github.com/1500256797))
|
||||
|
||||
## [0.61.0] - 2026-03-20
|
||||
|
||||
### Added
|
||||
|
||||
- Added `gpt-5.4-mini` model support for the `openai-codex` provider with Codex pricing metadata and unit coverage ([#2334](https://github.com/jae/agent-jae-cli/pull/2334) by [@justram](https://github.com/justram))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `validateToolArguments()` to fall back gracefully when AJV schema compilation is blocked in restricted runtimes such as Cloudflare Workers, allowing tool execution to proceed without schema validation ([#2395](https://github.com/jae/agent-jae-cli/issues/2395))
|
||||
- Fixed `google-vertex` API key resolution to ignore placeholder auth markers like `<authenticated>` and fall back to ADC instead of sending them as literal API keys ([#2335](https://github.com/jae/agent-jae-cli/issues/2335))
|
||||
- Fixed OpenRouter reasoning requests to use the provider's nested `reasoning.effort` payload instead of OpenAI's `reasoning_effort`, restoring thinking level support for OpenRouter models ([#2298](https://github.com/jae/agent-jae-cli/pull/2298) by [@PriNova](https://github.com/PriNova))
|
||||
- Fixed Bedrock prompt caching for application inference profiles by allowing cache points to be forced with `AWS_BEDROCK_FORCE_CACHE=1` when the profile ARN does not expose the underlying Claude model name ([#2346](https://github.com/jae/agent-jae-cli/pull/2346) by [@haoqixu](https://github.com/haoqixu))
|
||||
|
||||
## [0.60.0] - 2026-03-18
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Gemini 3 and Antigravity image tool results to stay inline as multimodal tool responses instead of being rerouted through separate follow-up messages ([#2052](https://github.com/jae/agent-jae-cli/issues/2052))
|
||||
- Fixed Bedrock Claude 4.6 model metadata to use the correct 200K context window instead of 1M ([#2305](https://github.com/jae/agent-jae-cli/issues/2305))
|
||||
- Fixed lazy built-in provider registration so compiled Bun binaries can still load providers on first use without eagerly bundling provider SDKs ([#2314](https://github.com/jae/agent-jae-cli/issues/2314))
|
||||
- Fixed built-in OAuth callback flows to share aligned callback handling across Anthropic, Gemini CLI, Antigravity, and OpenAI Codex, and fixed OpenAI Codex login to resolve immediately after callback completion ([#2316](https://github.com/jae/agent-jae-cli/issues/2316))
|
||||
- Fixed OpenAI-compatible z.ai `network_error` responses to surface as errors so callers can retry them instead of treating them as successful assistant messages ([#2313](https://github.com/jae/agent-jae-cli/issues/2313))
|
||||
- Fixed OpenAI Responses replay to normalize oversized resumed tool call IDs before sending them back to Codex and other Responses-compatible targets ([#2328](https://github.com/jae/agent-jae-cli/issues/2328))
|
||||
|
||||
## [0.59.0] - 2026-03-17
|
||||
|
||||
### Added
|
||||
|
||||
- Added `client` injection support to `AnthropicOptions`, allowing callers to provide a pre-built Anthropic-compatible client instead of constructing one internally.
|
||||
|
||||
### Changed
|
||||
|
||||
- Lazy-load built-in provider modules and root provider wrappers so importing `@jaeswift/jae-ai` no longer eagerly loads provider SDKs, significantly reducing base startup cost without changing dependency installation footprint ([#2297](https://github.com/jae/agent-jae-cli/issues/2297))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Added provider-specific `responseId` support on `AssistantMessage` for providers that expose upstream response or message identifiers, including Anthropic, OpenAI, Google, Gemini CLI, and Mistral, and added end-to-end coverage for supported OAuth and API key providers ([#2245](https://github.com/jae/agent-jae-cli/issues/2245))
|
||||
- Fixed Claude 4.6 context window overrides in generated model metadata so build-time catalogs reflect the intended values ([#2286](https://github.com/jae/agent-jae-cli/issues/2286))
|
||||
|
||||
## [0.58.4] - 2026-03-16
|
||||
|
||||
## [0.58.3] - 2026-03-15
|
||||
|
||||
## [0.58.2] - 2026-03-15
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Anthropic OAuth manual login and token refresh by using the localhost callback URI for pasted redirect/code flows and omitting `scope` from refresh-token requests ([#2169](https://github.com/jae/agent-jae-cli/issues/2169))
|
||||
|
||||
## [0.58.1] - 2026-03-14
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI Codex websocket protocol to include required headers and properly terminate SSE streams on connection close ([#1961](https://github.com/jae/agent-jae-cli/issues/1961))
|
||||
- Fixed Bedrock prompt caching being enabled for non-Claude models, causing API errors ([#2053](https://github.com/jae/agent-jae-cli/issues/2053))
|
||||
- Fixed Qwen models via OpenAI-compatible providers by adding `qwen-chat-template` compat mode that uses Qwen's native chat template format ([#2020](https://github.com/jae/agent-jae-cli/issues/2020))
|
||||
- Fixed Bedrock unsigned thinking replay to handle edge cases with empty or malformed thinking blocks ([#2063](https://github.com/jae/agent-jae-cli/issues/2063))
|
||||
- Fixed xhigh reasoning effort detection for Claude Opus 4.6 to match by model ID instead of requiring explicit capability flag ([#2040](https://github.com/jae/agent-jae-cli/issues/2040))
|
||||
- Handle `finish_reason: "end"` from Ollama/LM Studio by mapping it to `"stop"` instead of throwing ([#2142](https://github.com/jae/agent-jae-cli/issues/2142))
|
||||
|
||||
## [0.58.0] - 2026-03-14
|
||||
|
||||
### Added
|
||||
|
||||
- Added `GOOGLE_CLOUD_API_KEY` environment variable support for the `google-vertex` provider as an alternative to Application Default Credentials ([#1976](https://github.com/jae/agent-jae-cli/pull/1976) by [@gordonhwc](https://github.com/gordonhwc))
|
||||
|
||||
### Changed
|
||||
|
||||
- Raised Claude Opus 4.6, Sonnet 4.6, and related Bedrock model context windows from 200K to 1M tokens ([#2135](https://github.com/jae/agent-jae-cli/pull/2135) by [@mitsuhiko](https://github.com/mitsuhiko))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed GitHub Copilot device-code login polling to respect OAuth slow-down intervals, wait before the first token poll, and include a clearer clock-drift hint in WSL/VM environments when repeated slow-downs lead to timeout.
|
||||
- Fixed usage statistics not being captured for OpenAI-compatible providers that return usage in `choice.usage` instead of the standard `chunk.usage` (e.g., Moonshot/Kimi) ([#2017](https://github.com/jae/agent-jae-cli/issues/2017))
|
||||
- Fixed tool result images not being sent in `function_call_output` items for OpenAI Responses API providers, causing image data to be silently dropped in tool results ([#2104](https://github.com/jae/agent-jae-cli/issues/2104))
|
||||
- Fixed assistant content being sent as structured content blocks instead of plain strings in the `openai-completions` provider, causing errors with some OpenAI-compatible backends ([#2008](https://github.com/jae/agent-jae-cli/pull/2008) by [@geraldoaax](https://github.com/geraldoaax))
|
||||
- Fixed error details in OpenAI Responses `response.failed` handler to include status code, error code, and message instead of a generic failure ([#1956](https://github.com/jae/agent-jae-cli/pull/1956) by [@drewburr](https://github.com/drewburr))
|
||||
|
||||
## [0.57.1] - 2026-03-07
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed context overflow detection to recognize z.ai `model_context_window_exceeded` errors surfaced through OpenAI-compatible stop reason handling ([#1937](https://github.com/jae/agent-jae-cli/issues/1937))
|
||||
|
||||
## [0.57.0] - 2026-03-07
|
||||
|
||||
### Added
|
||||
|
||||
- Added per-request payload inspection and replacement hook support via `beforeProviderRequest`, allowing callers to inspect or replace provider payloads before sending.
|
||||
|
||||
## [0.56.3] - 2026-03-06
|
||||
|
||||
### Added
|
||||
|
||||
- Added `claude-sonnet-4-6` model for the `google-antigravity` provider ([#1859](https://github.com/jae/agent-jae-cli/issues/1859)).
|
||||
- Bumped default Antigravity User-Agent version to `1.18.4` ([#1859](https://github.com/jae/agent-jae-cli/issues/1859)).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Antigravity Claude thinking beta header detection to use provider and model capability instead of `-thinking` suffix, so models like `claude-sonnet-4-6` receive the header correctly ([#1859](https://github.com/jae/agent-jae-cli/issues/1859)).
|
||||
- Fixed OpenAI Responses reasoning replay regression that dropped reasoning blocks on follow-up turns ([#1878](https://github.com/jae/agent-jae-cli/issues/1878))
|
||||
|
||||
## [0.56.2] - 2026-03-05
|
||||
|
||||
### Added
|
||||
|
||||
- Added `gpt-5.4` model support for `openai`, `openai-codex`, `azure-openai-responses`, and `opencode` providers, with GPT-5.4 treated as xhigh-capable and capped to a 272000 context window in built-in metadata.
|
||||
- Added `gpt-5.3-codex` fallback model availability for `github-copilot` until upstream model catalogs include it ([#1853](https://github.com/jae/agent-jae-cli/issues/1853)).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Preserved OpenAI Responses assistant `phase` metadata (`commentary`, `final_answer`) across turns by encoding `id` and `phase` in `textSignature` for session persistence and replay, with backward compatibility for legacy plain signatures ([#1819](https://github.com/jae/agent-jae-cli/issues/1819)).
|
||||
- Fixed OpenAI Responses replay to omit empty thinking blocks, avoiding invalid no-op reasoning items in follow-up turns.
|
||||
- Switched the Mistral provider from the OpenAI-compatible completions path to Mistral's native SDK and conversations API, preserving native thinking blocks and Mistral-specific message semantics across turns ([#1716](https://github.com/jae/agent-jae-cli/issues/1716)).
|
||||
- Fixed Antigravity endpoint fallback: 403/404 responses now cascade to the next endpoint instead of throwing immediately, added `autopush-cloudcode-pa.sandbox` endpoint to the fallback list, and removed extra fingerprint headers (`X-Goog-Api-Client`, `Client-Metadata`) from Antigravity requests ([#1830](https://github.com/jae/agent-jae-cli/issues/1830)).
|
||||
- Fixed `@jaeswift/jae-ai/oauth` package exports to point directly at built `dist` files, avoiding broken TypeScript resolution through unpublished wrapper targets ([#1856](https://github.com/jae/agent-jae-cli/issues/1856)).
|
||||
- Fixed Gemini 3 unsigned tool call replay: use `skip_thought_signature_validator` sentinel instead of converting function calls to text, preserving structured tool call context across multi-turn conversations ([#1829](https://github.com/jae/agent-jae-cli/issues/1829)).
|
||||
|
||||
## [0.56.1] - 2026-03-05
|
||||
|
||||
## [0.56.0] - 2026-03-04
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- Moved Node OAuth runtime exports off the top-level package entry. Import OAuth login/refresh functions from `@jaeswift/jae-ai/oauth` instead of `@jaeswift/jae-ai` ([#1814](https://github.com/jae/agent-jae-cli/issues/1814))
|
||||
|
||||
### Added
|
||||
|
||||
- Added `gemini-3.1-flash-lite-preview` fallback model entry for the `google` provider so it remains selectable until upstream model catalogs include it ([#1785](https://github.com/jae/agent-jae-cli/issues/1785), thanks [@n-WN](https://github.com/n-WN)).
|
||||
- Added OpenCode Go provider support with `opencode-go` model catalog entries and `OPENCODE_API_KEY` environment variable support ([#1757](https://github.com/jae/agent-jae-cli/issues/1757)).
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated Antigravity Gemini 3.1 model metadata and request headers to match current upstream behavior.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Gemini 3.1 thinking-level detection in `google` and `google-vertex` providers so `gemini-3.1-*` models use Gemini 3 level-based thinking config instead of budget fallback ([#1785](https://github.com/jae/agent-jae-cli/issues/1785), thanks [@n-WN](https://github.com/n-WN)).
|
||||
- Fixed browser bundling failures by lazy-loading the Bedrock provider and removing Node-only side effects from the default browser import graph ([#1814](https://github.com/jae/agent-jae-cli/issues/1814)).
|
||||
- Fixed `ERR_VM_DYNAMIC_IMPORT_CALLBACK_MISSING` failures by replacing `Function`-based dynamic imports with module dynamic imports in browser-safe provider loading paths ([#1814](https://github.com/jae/agent-jae-cli/issues/1814)).
|
||||
- Fixed Bedrock region resolution for `AWS_PROFILE` by honoring `region` from the selected profile when present ([#1800](https://github.com/jae/agent-jae-cli/issues/1800)).
|
||||
- Fixed Groq Qwen3 reasoning effort mapping by translating unsupported effort values to provider-supported values ([#1745](https://github.com/jae/agent-jae-cli/issues/1745)).
|
||||
|
||||
## [0.55.4] - 2026-03-02
|
||||
|
||||
## [0.55.3] - 2026-02-27
|
||||
|
||||
## [0.55.2] - 2026-02-27
|
||||
|
||||
### Fixed
|
||||
|
||||
- Restored built-in OAuth providers when unregistering dynamically registered provider IDs and added `resetOAuthProviders()` for registry reset flows.
|
||||
- Fixed Z.ai thinking control using wrong parameter name (`thinking` instead of `enable_thinking`), causing thinking to always be enabled and wasting tokens/latency ([#1674](https://github.com/jae/agent-jae-cli/pull/1674) by [@okuyam2y](https://github.com/okuyam2y))
|
||||
- Fixed `redacted_thinking` blocks being silently dropped during Anthropic streaming. They are now captured as `ThinkingContent` with `redacted: true`, passed back to the API in multi-turn conversations, and handled in cross-model message transformation ([#1665](https://github.com/jae/agent-jae-cli/pull/1665) by [@tctev](https://github.com/tctev))
|
||||
- Fixed `interleaved-thinking-2025-05-14` beta header being sent for adaptive thinking models (Opus 4.6, Sonnet 4.6) where the header is deprecated or redundant ([#1665](https://github.com/jae/agent-jae-cli/pull/1665) by [@tctev](https://github.com/tctev))
|
||||
- Fixed temperature being sent alongside extended thinking, which is incompatible with both adaptive and budget-based thinking modes ([#1665](https://github.com/jae/agent-jae-cli/pull/1665) by [@tctev](https://github.com/tctev))
|
||||
- Fixed `(external, cli)` user-agent flag causing 401 errors on Anthropic setup-token endpoint ([#1677](https://github.com/jae/agent-jae-cli/pull/1677) by [@LazerLance777](https://github.com/LazerLance777))
|
||||
- Fixed crash when OpenAI-compatible provider returns a chunk with no `choices` array by adding optional chaining ([#1671](https://github.com/jae/agent-jae-cli/issues/1671))
|
||||
|
||||
## [0.55.1] - 2026-02-26
|
||||
|
||||
### Added
|
||||
|
||||
- Added `gemini-3.1-pro-preview` model support to the `google-gemini-cli` provider ([#1599](https://github.com/jae/agent-jae-cli/pull/1599) by [@audichuang](https://github.com/audichuang))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed adaptive thinking for Claude Sonnet 4.6 in Anthropic and Bedrock providers, and clamped unsupported `xhigh` effort values to supported levels ([#1548](https://github.com/jae/agent-jae-cli/pull/1548) by [@tctev](https://github.com/tctev))
|
||||
- Fixed Vertex ADC credential detection race by avoiding caching a false negative during async import initialization ([#1550](https://github.com/jae/agent-jae-cli/pull/1550) by [@jeremiahgaylord-web](https://github.com/jeremiahgaylord-web))
|
||||
|
||||
## [0.55.0] - 2026-02-24
|
||||
|
||||
## [0.54.2] - 2026-02-23
|
||||
|
||||
## [0.54.1] - 2026-02-22
|
||||
|
||||
## [0.54.0] - 2026-02-19
|
||||
|
||||
## [0.53.1] - 2026-02-19
|
||||
|
||||
## [0.53.0] - 2026-02-17
|
||||
|
||||
### Added
|
||||
|
||||
- Added Anthropic `claude-sonnet-4-6` fallback model entry to generated model definitions.
|
||||
|
||||
## [0.52.12] - 2026-02-13
|
||||
|
||||
### Added
|
||||
|
||||
- Added `transport` to `StreamOptions` with values `"sse"`, `"websocket"`, and `"auto"` (currently supported by `openai-codex-responses`).
|
||||
- Added WebSocket transport support for OpenAI Codex Responses (`openai-codex-responses`).
|
||||
|
||||
### Changed
|
||||
|
||||
- OpenAI Codex Responses now defaults to SSE transport unless `transport` is explicitly set.
|
||||
- OpenAI Codex Responses WebSocket connections are cached per `sessionId` and expire after 5 minutes of inactivity.
|
||||
|
||||
## [0.52.11] - 2026-02-13
|
||||
|
||||
### Added
|
||||
|
||||
- Added MiniMax M2.5 model entries for `minimax`, `minimax-cn`, `openrouter`, and `vercel-ai-gateway` providers, plus `minimax-m2.5-free` for `opencode`.
|
||||
|
||||
## [0.52.10] - 2026-02-12
|
||||
|
||||
### Added
|
||||
|
||||
- Added optional `metadata` field to `StreamOptions` for passing provider-specific metadata (e.g. Anthropic `user_id` for abuse tracking/rate limiting) ([#1384](https://github.com/jae/agent-jae-cli/pull/1384) by [@7Sageer](https://github.com/7Sageer))
|
||||
- Added `gpt-5.3-codex-spark` model definition for OpenAI and OpenAI Codex providers (128k context, text-only, research preview). Not yet functional, may become available in the next few hours or days.
|
||||
|
||||
### Changed
|
||||
|
||||
- Routed GitHub Copilot Claude 4.x models through Anthropic Messages API, centralized Copilot dynamic header handling, and added Copilot Claude Anthropic stream coverage ([#1353](https://github.com/jae/agent-jae-cli/pull/1353) by [@NateSmyth](https://github.com/NateSmyth))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI completions and responses streams to tolerate malformed trailing tool-call JSON without failing parsing ([#1424](https://github.com/jae/agent-jae-cli/issues/1424))
|
||||
|
||||
## [0.52.9] - 2026-02-08
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated the Antigravity system instruction to a more compact version for Google Gemini CLI compatibility
|
||||
|
||||
### Fixed
|
||||
|
||||
- Use `parametersJsonSchema` for Google provider tool declarations to support full JSON Schema (anyOf, oneOf, const, etc.) ([#1398](https://github.com/jae/agent-jae-cli/issues/1398) by [@jarib](https://github.com/jarib))
|
||||
- Reverted incorrect Antigravity model change: `claude-opus-4-6-thinking` back to `claude-opus-4-5-thinking` (model doesn't exist on Antigravity endpoint)
|
||||
- Corrected opencode context windows for Claude Sonnet 4 and 4.5 ([#1383](https://github.com/jae/agent-jae-cli/issues/1383))
|
||||
|
||||
## [0.52.8] - 2026-02-07
|
||||
|
||||
### Added
|
||||
|
||||
- Added OpenRouter `auto` model alias for automatic model routing ([#1361](https://github.com/jae/agent-jae-cli/pull/1361) by [@yogasanas](https://github.com/yogasanas))
|
||||
|
||||
### Changed
|
||||
|
||||
- Replaced Claude Opus 4.5 with Opus 4.6 in model definitions ([#1345](https://github.com/jae/agent-jae-cli/pull/1345) by [@calvin-hpnet](https://github.com/calvin-hpnet))
|
||||
|
||||
## [0.52.7] - 2026-02-06
|
||||
|
||||
### Added
|
||||
|
||||
- Added `AWS_BEDROCK_SKIP_AUTH` and `AWS_BEDROCK_FORCE_HTTP1` environment variables for connecting to unauthenticated Bedrock proxies ([#1320](https://github.com/jae/agent-jae-cli/pull/1320) by [@virtuald](https://github.com/virtuald))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Set OpenAI Responses API requests to `store: false` by default to avoid server-side history logging ([#1308](https://github.com/jae/agent-jae-cli/issues/1308))
|
||||
- Re-exported TypeBox `Type`, `Static`, and `TSchema` from `@jaeswift/jae-ai` to match documentation and avoid duplicate TypeBox type identity issues in pnpm setups ([#1338](https://github.com/jae/agent-jae-cli/issues/1338))
|
||||
- Fixed Bedrock adaptive thinking handling for Claude Opus 4.6 with interleaved thinking beta responses ([#1323](https://github.com/jae/agent-jae-cli/pull/1323) by [@markusylisiurunen](https://github.com/markusylisiurunen))
|
||||
- Fixed `AWS_BEDROCK_SKIP_AUTH` environment detection to avoid `process` access in non-Node.js environments
|
||||
|
||||
## [0.52.6] - 2026-02-05
|
||||
|
||||
## [0.52.5] - 2026-02-05
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `supportsXhigh()` to treat Anthropic Messages Opus 4.6 models as xhigh-capable so `streamSimple` can map `xhigh` to adaptive effort `max`
|
||||
|
||||
## [0.52.4] - 2026-02-05
|
||||
|
||||
## [0.52.3] - 2026-02-05
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Bedrock Opus 4.6 model IDs (removed `:0` suffix) and cache pricing for `us.*` and `eu.*` variants
|
||||
- Added missing `eu.anthropic.claude-opus-4-6-v1` inference profile to model catalog
|
||||
- Fixed Claude Opus 4.6 context window metadata to 200000 for Anthropic and OpenCode providers
|
||||
|
||||
## [0.52.2] - 2026-02-05
|
||||
|
||||
## [0.52.1] - 2026-02-05
|
||||
|
||||
### Added
|
||||
|
||||
- Added adaptive thinking support for Claude Opus 4.6 with effort levels (`low`, `medium`, `high`, `max`)
|
||||
- Added `effort` option to `AnthropicOptions` for controlling adaptive thinking depth
|
||||
- `thinkingEnabled` now automatically uses adaptive thinking for Opus 4.6+ models and budget-based thinking for older models
|
||||
- `streamSimple`/`completeSimple` automatically map `ThinkingLevel` to effort levels for Opus 4.6
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated `@anthropic-ai/sdk` to 0.73.0
|
||||
- Updated `@aws-sdk/client-bedrock-runtime` to 3.983.0
|
||||
- Updated `@google/genai` to 1.40.0
|
||||
- Removed `fast-xml-parser` override (no longer needed)
|
||||
|
||||
## [0.52.0] - 2026-02-05
|
||||
|
||||
### Added
|
||||
|
||||
- Added Claude Opus 4.6 model to the generated model catalog
|
||||
- Added GPT-5.3 Codex model to the generated model catalog (OpenAI Codex provider only)
|
||||
|
||||
## [0.51.6] - 2026-02-04
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI Codex Responses provider to respect configured baseUrl ([#1244](https://github.com/jae/agent-jae-cli/issues/1244))
|
||||
|
||||
## [0.51.5] - 2026-02-04
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed Bedrock model generation to drop legacy workarounds now handled upstream ([#1239](https://github.com/jae/agent-jae-cli/pull/1239) by [@unexge](https://github.com/unexge))
|
||||
|
||||
## [0.51.4] - 2026-02-03
|
||||
|
||||
## [0.51.3] - 2026-02-03
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed xhigh thinking level support check to accept gpt-5.2 model IDs ([#1209](https://github.com/jae/agent-jae-cli/issues/1209))
|
||||
|
||||
## [0.51.2] - 2026-02-03
|
||||
|
||||
## [0.51.1] - 2026-02-02
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `cache_control` not being applied to string-format user messages in Anthropic provider
|
||||
|
||||
## [0.51.0] - 2026-02-01
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `cacheRetention` option not being passed through in `buildBaseOptions` ([#1154](https://github.com/jae/agent-jae-cli/issues/1154))
|
||||
- Fixed OAuth login/refresh not using HTTP proxy settings (`HTTP_PROXY`, `HTTPS_PROXY` env vars) ([#1132](https://github.com/jae/agent-jae-cli/issues/1132))
|
||||
- Fixed OpenAI-compatible completions to omit unsupported `strict` tool fields for providers that reject them ([#1172](https://github.com/jae/agent-jae-cli/issues/1172))
|
||||
|
||||
## [0.50.9] - 2026-02-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added `PI_AI_ANTIGRAVITY_VERSION` environment variable to override the Antigravity User-Agent version when Google updates their version requirements ([#1129](https://github.com/jae/agent-jae-cli/issues/1129))
|
||||
- Added `cacheRetention` stream option with provider-specific mappings for prompt cache controls, defaulting to short retention ([#1134](https://github.com/jae/agent-jae-cli/issues/1134))
|
||||
|
||||
## [0.50.8] - 2026-02-01
|
||||
|
||||
### Added
|
||||
|
||||
- Added `maxRetryDelayMs` option to `StreamOptions` to cap server-requested retry delays. When a provider (e.g., Google Gemini CLI) requests a delay longer than this value, the request fails immediately with an informative error instead of waiting silently. Default: 60000ms (60 seconds). Set to 0 to disable the cap. ([#1123](https://github.com/jae/agent-jae-cli/issues/1123))
|
||||
- Added Qwen thinking format support for OpenAI-compatible completions via `enable_thinking`. ([#940](https://github.com/jae/agent-jae-cli/pull/940) by [@4h9fbZ](https://github.com/4h9fbZ))
|
||||
|
||||
## [0.50.7] - 2026-01-31
|
||||
|
||||
## [0.50.6] - 2026-01-30
|
||||
|
||||
## [0.50.5] - 2026-01-30
|
||||
|
||||
## [0.50.4] - 2026-01-30
|
||||
|
||||
### Added
|
||||
|
||||
- Added Vercel AI Gateway routing support via `vercelGatewayRouting` option in model config ([#1051](https://github.com/jae/agent-jae-cli/pull/1051) by [@ben-vargas](https://github.com/ben-vargas))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Updated Antigravity User-Agent from 1.11.5 to 1.15.8 to fix rejected requests ([#1079](https://github.com/jae/agent-jae-cli/issues/1079))
|
||||
- Fixed tool call argument defaults for Anthropic and Google history conversion when providers omit inputs ([#1065](https://github.com/jae/agent-jae-cli/issues/1065))
|
||||
|
||||
## [0.50.3] - 2026-01-29
|
||||
|
||||
### Added
|
||||
|
||||
- Added Kimi For Coding provider support (Moonshot AI's Anthropic-compatible coding API)
|
||||
|
||||
## [0.50.2] - 2026-01-29
|
||||
|
||||
### Added
|
||||
|
||||
- Added Hugging Face provider support via OpenAI-compatible Inference Router ([#994](https://github.com/jae/agent-jae-cli/issues/994))
|
||||
- Added `PI_CACHE_RETENTION` environment variable to control cache TTL for Anthropic (5m vs 1h) and OpenAI (in-memory vs 24h). Set to `long` for extended retention. Only applies to direct API calls (api.anthropic.com, api.openai.com). ([#967](https://github.com/jae/agent-jae-cli/issues/967))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI completions `toolChoice` handling to correctly set `type: "function"` wrapper ([#998](https://github.com/jae/agent-jae-cli/pull/998) by [@williamtwomey](https://github.com/williamtwomey))
|
||||
- Fixed cross-provider handoff failing when switching from OpenAI Responses API providers (github-copilot, openai-codex) to other providers due to pipe-separated tool call IDs not being normalized, and trailing underscores in truncated IDs being rejected by OpenAI Codex ([#1022](https://github.com/jae/agent-jae-cli/issues/1022))
|
||||
- Fixed 429 rate limit errors incorrectly triggering auto-compaction instead of retry with backoff ([#1038](https://github.com/jae/agent-jae-cli/issues/1038))
|
||||
- Fixed Anthropic provider to handle `sensitive` stop_reason returned by API ([#978](https://github.com/jae/agent-jae-cli/issues/978))
|
||||
- Fixed DeepSeek API compatibility by detecting `deepseek.com` URLs and disabling unsupported `developer` role ([#1048](https://github.com/jae/agent-jae-cli/issues/1048))
|
||||
- Fixed Anthropic provider to preserve input token counts when proxies omit them in `message_delta` events ([#1045](https://github.com/jae/agent-jae-cli/issues/1045))
|
||||
|
||||
## [0.50.1] - 2026-01-26
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenCode Zen model generation to exclude deprecated models ([#970](https://github.com/jae/agent-jae-cli/pull/970) by [@DanielTatarkin](https://github.com/DanielTatarkin))
|
||||
|
||||
## [0.50.0] - 2026-01-26
|
||||
|
||||
### Added
|
||||
|
||||
- Added OpenRouter provider routing support for custom models via `openRouterRouting` compat field ([#859](https://github.com/jae/agent-jae-cli/pull/859) by [@v01dpr1mr0s3](https://github.com/v01dpr1mr0s3))
|
||||
- Added `azure-openai-responses` provider support for Azure OpenAI Responses API. ([#890](https://github.com/jae/agent-jae-cli/pull/890) by [@markusylisiurunen](https://github.com/markusylisiurunen))
|
||||
- Added HTTP proxy environment variable support for API requests ([#942](https://github.com/jae/agent-jae-cli/pull/942) by [@haoqixu](https://github.com/haoqixu))
|
||||
- Added `createAssistantMessageEventStream()` factory function for use in extensions.
|
||||
- Added `resetApiProviders()` to clear and re-register built-in API providers.
|
||||
|
||||
### Changed
|
||||
|
||||
- Refactored API streaming dispatch to use an API registry with provider-owned `streamSimple` mapping.
|
||||
- Moved environment API key resolution to `env-api-keys.ts` and re-exported it from the package entrypoint.
|
||||
- Azure OpenAI Responses provider now uses base URL configuration with deployment-aware model mapping and no longer includes service tier handling.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Bun runtime detection for dynamic imports in browser-compatible modules (stream.ts, openai-codex-responses.ts, openai-codex.ts) ([#922](https://github.com/jae/agent-jae-cli/pull/922) by [@dannote](https://github.com/dannote))
|
||||
- Fixed streaming functions to use `model.api` instead of hardcoded API types
|
||||
- Fixed Google providers to default tool call arguments to an empty object when omitted
|
||||
- Fixed OpenAI Responses streaming to handle `arguments.done` events on OpenAI-compatible endpoints ([#917](https://github.com/jae/agent-jae-cli/pull/917) by [@williballenthin](https://github.com/williballenthin))
|
||||
- Fixed OpenAI Codex Responses tool strictness handling after the shared responses refactor
|
||||
- Fixed Azure OpenAI Responses streaming to guard deltas before content parts and correct metadata and handoff gating
|
||||
- Fixed OpenAI completions tool-result image batching after consecutive tool results ([#902](https://github.com/jae/agent-jae-cli/pull/902) by [@terrorobe](https://github.com/terrorobe))
|
||||
|
||||
## [0.49.3] - 2026-01-22
|
||||
|
||||
### Added
|
||||
|
||||
- Added `headers` option to `StreamOptions` for custom HTTP headers in API requests. Supported by all providers except Amazon Bedrock (which uses AWS SDK auth). Headers are merged with provider defaults and `model.headers`, with `options.headers` taking precedence.
|
||||
- Added `originator` option to `loginOpenAICodex()` for custom OAuth client identification
|
||||
- Browser compatibility for pi-ai: replaced top-level Node.js imports with dynamic imports for browser environments ([#873](https://github.com/jae/agent-jae-cli/issues/873))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI Responses API 400 error "function_call without required reasoning item" when switching between models (same provider, different model). The fix omits the `id` field for function_calls from different models to avoid triggering OpenAI's reasoning/function_call pairing validation ([#886](https://github.com/jae/agent-jae-cli/issues/886))
|
||||
|
||||
## [0.49.2] - 2026-01-19
|
||||
|
||||
### Added
|
||||
|
||||
- Added AWS credential detection for ECS/Kubernetes environments: `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`, `AWS_CONTAINER_CREDENTIALS_FULL_URI`, `AWS_WEB_IDENTITY_TOKEN_FILE` ([#848](https://github.com/jae/agent-jae-cli/issues/848))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI Responses 400 error "reasoning without following item" by skipping errored/aborted assistant messages entirely in transform-messages.ts ([#838](https://github.com/jae/agent-jae-cli/pull/838))
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed `strictResponsesPairing` compat option (no longer needed after the transform-messages fix)
|
||||
|
||||
## [0.49.1] - 2026-01-18
|
||||
|
||||
### Added
|
||||
|
||||
- Added `OpenAIResponsesCompat` interface with `strictResponsesPairing` option for Azure OpenAI Responses API, which requires strict reasoning/message pairing in history replay ([#768](https://github.com/jae/agent-jae-cli/pull/768) by [@prateekmedia](https://github.com/prateekmedia))
|
||||
|
||||
### Changed
|
||||
|
||||
- Split `OpenAICompat` into `OpenAICompletionsCompat` and `OpenAIResponsesCompat` for type-safe API-specific compat settings
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed tool call ID normalization for cross-provider handoffs (e.g., Codex to Antigravity Claude) ([#821](https://github.com/jae/agent-jae-cli/issues/821))
|
||||
|
||||
## [0.49.0] - 2026-01-17
|
||||
|
||||
### Changed
|
||||
|
||||
- OpenAI Codex responses now use the context system prompt directly in the instructions field.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed orphaned tool results after errored assistant messages causing Codex API errors. When an assistant message has `stopReason: "error"`, its tool calls are now excluded from pending tool tracking, preventing synthetic tool results from being generated for calls that will be dropped by provider-specific converters. ([#812](https://github.com/jae/agent-jae-cli/issues/812))
|
||||
- Fixed Bedrock Claude max_tokens handling to always exceed thinking budget tokens, preventing compaction failures. ([#797](https://github.com/jae/agent-jae-cli/pull/797) by [@pjtf93](https://github.com/pjtf93))
|
||||
- Fixed Claude Code tool name normalization to match the Claude Code tool list case-insensitively and remove invalid mappings.
|
||||
|
||||
## [0.48.0] - 2026-01-16
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI-compatible provider feature detection to use `model.provider` in addition to URL, allowing custom base URLs (e.g., proxies) to work correctly with provider-specific settings ([#774](https://github.com/jae/agent-jae-cli/issues/774))
|
||||
- Fixed Gemini 3 context loss when switching from providers without thought signatures: unsigned tool calls are now converted to text with anti-mimicry notes instead of being skipped
|
||||
- Fixed string numbers in tool arguments not being coerced to numbers during validation ([#786](https://github.com/jae/agent-jae-cli/pull/786) by [@dannote](https://github.com/dannote))
|
||||
- Fixed Bedrock tool call IDs to use only alphanumeric characters, avoiding API errors from invalid characters ([#781](https://github.com/jae/agent-jae-cli/pull/781) by [@pjtf93](https://github.com/pjtf93))
|
||||
- Fixed empty error assistant messages (from 429/500 errors) breaking the tool_use to tool_result chain by filtering them in `transformMessages`
|
||||
|
||||
## [0.47.0] - 2026-01-16
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenCode provider's `/v1` endpoint to use `system` role instead of `developer` role, fixing `400 Incorrect role information` error for models using `openai-completions` API ([#755](https://github.com/jae/agent-jae-cli/pull/755) by [@melihmucuk](https://github.com/melihmucuk))
|
||||
- Added retry logic to OpenAI Codex provider for transient errors (429, 5xx, connection failures). Uses exponential backoff with up to 3 retries. ([#733](https://github.com/jae/agent-jae-cli/issues/733))
|
||||
|
||||
## [0.46.0] - 2026-01-15
|
||||
|
||||
### Added
|
||||
|
||||
- Added MiniMax China (`minimax-cn`) provider support ([#725](https://github.com/jae/agent-jae-cli/pull/725) by [@tallshort](https://github.com/tallshort))
|
||||
- Added `gpt-5.2-codex` models for GitHub Copilot and OpenCode Zen providers ([#734](https://github.com/jae/agent-jae-cli/pull/734) by [@aadishv](https://github.com/aadishv))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Avoid unsigned Gemini 3 tool calls ([#741](https://github.com/jae/agent-jae-cli/pull/741) by [@roshanasingh4](https://github.com/roshanasingh4))
|
||||
- Fixed signature support for non-Anthropic models in Amazon Bedrock provider ([#727](https://github.com/jae/agent-jae-cli/pull/727) by [@unexge](https://github.com/unexge))
|
||||
|
||||
## [0.45.7] - 2026-01-13
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI Responses timeout option handling ([#706](https://github.com/jae/agent-jae-cli/pull/706) by [@markusylisiurunen](https://github.com/markusylisiurunen))
|
||||
- Fixed Bedrock tool call conversion to apply message transforms ([#707](https://github.com/jae/agent-jae-cli/pull/707) by [@pjtf93](https://github.com/pjtf93))
|
||||
|
||||
## [0.45.6] - 2026-01-13
|
||||
|
||||
### Fixed
|
||||
|
||||
- Export `parseStreamingJson` from main package for tsx dev mode compatibility
|
||||
|
||||
## [0.45.5] - 2026-01-13
|
||||
|
||||
## [0.45.4] - 2026-01-13
|
||||
|
||||
### Added
|
||||
|
||||
- Added Vercel AI Gateway provider with model discovery and `AI_GATEWAY_API_KEY` env support ([#689](https://github.com/jae/agent-jae-cli/pull/689) by [@timolins](https://github.com/timolins))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed z.ai thinking/reasoning: z.ai uses `thinking: { type: "enabled" }` instead of OpenAI's `reasoning_effort`. Added `thinkingFormat` compat flag to handle this. ([#688](https://github.com/jae/agent-jae-cli/issues/688))
|
||||
|
||||
## [0.45.3] - 2026-01-13
|
||||
|
||||
## [0.45.2] - 2026-01-13
|
||||
|
||||
## [0.45.1] - 2026-01-13
|
||||
|
||||
## [0.45.0] - 2026-01-13
|
||||
|
||||
### Added
|
||||
|
||||
- MiniMax provider support with M2 and M2.1 models via Anthropic-compatible API ([#656](https://github.com/jae/agent-jae-cli/pull/656) by [@dannote](https://github.com/dannote))
|
||||
- Add Amazon Bedrock provider with prompt caching for Claude models (experimental, tested with Anthropic Claude models only) ([#494](https://github.com/jae/agent-jae-cli/pull/494) by [@unexge](https://github.com/unexge))
|
||||
- Added `serviceTier` option for OpenAI Responses requests ([#672](https://github.com/jae/agent-jae-cli/pull/672) by [@markusylisiurunen](https://github.com/markusylisiurunen))
|
||||
- **Anthropic caching on OpenRouter**: Interactions with Anthropic models via OpenRouter now set a 5-minute cache point using Anthropic-style `cache_control` breakpoints on the last assistant or user message. ([#584](https://github.com/jae/agent-jae-cli/pull/584) by [@nathyong](https://github.com/nathyong))
|
||||
- **Google Gemini CLI provider improvements**: Added Antigravity endpoint fallback (tries daily sandbox then prod when `baseUrl` is unset), header-based retry delay parsing (`Retry-After`, `x-ratelimit-reset`, `x-ratelimit-reset-after`), stable `sessionId` derivation from first user message for cache affinity, empty SSE stream retry with backoff, and `anthropic-beta` header for Claude thinking models ([#670](https://github.com/jae/agent-jae-cli/pull/670) by [@kim0](https://github.com/kim0))
|
||||
|
||||
## [0.44.0] - 2026-01-12
|
||||
|
||||
## [0.43.0] - 2026-01-11
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Google provider thinking detection: `isThinkingPart()` now only checks `thought === true`, not `thoughtSignature`. Per Google docs, `thoughtSignature` is for context replay and can appear on any part type. Also removed `id` field from `functionCall`/`functionResponse` (rejected by Vertex AI and Cloud Code Assist), and added `textSignature` round-trip for multi-turn reasoning context. ([#631](https://github.com/jae/agent-jae-cli/pull/631) by [@theBucky](https://github.com/theBucky))
|
||||
|
||||
## [0.42.5] - 2026-01-11
|
||||
|
||||
## [0.42.4] - 2026-01-10
|
||||
|
||||
## [0.42.3] - 2026-01-10
|
||||
|
||||
### Changed
|
||||
|
||||
- OpenAI Codex: switched to bundled system prompt matching opencode, changed originator to "pi", simplified prompt handling
|
||||
|
||||
## [0.42.2] - 2026-01-10
|
||||
|
||||
### Added
|
||||
|
||||
- Added `GOOGLE_APPLICATION_CREDENTIALS` env var support for Vertex AI credential detection (standard for CI/production).
|
||||
- Added `supportsUsageInStreaming` compatibility flag for OpenAI-compatible providers that reject `stream_options: { include_usage: true }`. Defaults to `true`. Set to `false` in model config for providers like gatewayz.ai. ([#596](https://github.com/jae/agent-jae-cli/pull/596) by [@XesGaDeus](https://github.com/XesGaDeus))
|
||||
- Improved Google model pricing info ([#588](https://github.com/jae/agent-jae-cli/pull/588) by [@aadishv](https://github.com/aadishv))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `os.homedir()` calls at module load time; now resolved lazily when needed.
|
||||
- Fixed OpenAI Responses tool strict flag to use a boolean for LM Studio compatibility ([#598](https://github.com/jae/agent-jae-cli/pull/598) by [@gnattu](https://github.com/gnattu))
|
||||
- Fixed Google Cloud Code Assist OAuth for paid subscriptions: properly handles long-running operations for project provisioning, supports `GOOGLE_CLOUD_PROJECT` / `GOOGLE_CLOUD_PROJECT_ID` env vars for paid tiers, and handles VPC-SC affected users ([#582](https://github.com/jae/agent-jae-cli/pull/582) by [@cmf](https://github.com/cmf))
|
||||
|
||||
## [0.42.1] - 2026-01-09
|
||||
|
||||
## [0.42.0] - 2026-01-09
|
||||
|
||||
### Added
|
||||
|
||||
- Added OpenCode Zen provider support with 26 models (Claude, GPT, Gemini, Grok, Kimi, GLM, Qwen, etc.). Set `OPENCODE_API_KEY` env var to use.
|
||||
|
||||
## [0.41.0] - 2026-01-09
|
||||
|
||||
## [0.40.1] - 2026-01-09
|
||||
|
||||
## [0.40.0] - 2026-01-08
|
||||
|
||||
## [0.39.1] - 2026-01-08
|
||||
|
||||
## [0.39.0] - 2026-01-08
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Gemini CLI abort handling: detect native `AbortError` in retry catch block, cancel SSE reader when abort signal fires ([#568](https://github.com/jae/agent-jae-cli/pull/568) by [@tmustier](https://github.com/tmustier))
|
||||
- Fixed Antigravity provider 429 errors by aligning request payload with CLIProxyAPI v6.6.89: inject Antigravity system instruction with `role: "user"`, set `requestType: "agent"`, and use `antigravity` userAgent. Added bridge prompt to override Antigravity behavior (identity, paths, web dev guidelines) with Pi defaults. ([#571](https://github.com/jae/agent-jae-cli/pull/571) by [@ben-vargas](https://github.com/ben-vargas))
|
||||
- Fixed thinking block handling for cross-model conversations: thinking blocks are now converted to plain text (no `<thinking>` tags) when switching models. Previously, `<thinking>` tags caused models to mimic the pattern and output literal tags. Also fixed empty thinking blocks causing API errors. ([#561](https://github.com/jae/agent-jae-cli/issues/561))
|
||||
|
||||
## [0.38.0] - 2026-01-08
|
||||
|
||||
### Added
|
||||
|
||||
- `thinkingBudgets` option in `SimpleStreamOptions` for customizing token budgets per thinking level on token-based providers ([#529](https://github.com/jae/agent-jae-cli/pull/529) by [@melihmucuk](https://github.com/melihmucuk))
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- Removed OpenAI Codex model aliases (`gpt-5`, `gpt-5-mini`, `gpt-5-nano`, `codex-mini-latest`, `gpt-5-codex`, `gpt-5.1-codex`, `gpt-5.1-chat-latest`). Use canonical model IDs: `gpt-5.1`, `gpt-5.1-codex-max`, `gpt-5.1-codex-mini`, `gpt-5.2`, `gpt-5.2-codex`. ([#536](https://github.com/jae/agent-jae-cli/pull/536) by [@ghoulr](https://github.com/ghoulr))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI Codex context window from 400,000 to 272,000 tokens to match Codex CLI defaults and prevent 400 errors. ([#536](https://github.com/jae/agent-jae-cli/pull/536) by [@ghoulr](https://github.com/ghoulr))
|
||||
- Fixed Codex SSE error events to surface message, code, and status. ([#551](https://github.com/jae/agent-jae-cli/pull/551) by [@tmustier](https://github.com/tmustier))
|
||||
- Fixed context overflow detection for `context_length_exceeded` error codes.
|
||||
|
||||
## [0.37.8] - 2026-01-07
|
||||
|
||||
## [0.37.7] - 2026-01-07
|
||||
|
||||
## [0.37.6] - 2026-01-06
|
||||
|
||||
### Added
|
||||
|
||||
- Exported OpenAI Codex utilities: `CacheMetadata`, `getCodexInstructions`, `getModelFamily`, `ModelFamily`, `buildCodexPiBridge`, `buildCodexSystemPrompt`, `CodexSystemPrompt` ([#510](https://github.com/jae/agent-jae-cli/pull/510) by [@mitsuhiko](https://github.com/mitsuhiko))
|
||||
|
||||
## [0.37.5] - 2026-01-06
|
||||
|
||||
## [0.37.4] - 2026-01-06
|
||||
|
||||
## [0.37.3] - 2026-01-06
|
||||
|
||||
### Added
|
||||
|
||||
- `sessionId` option in `StreamOptions` for providers that support session-based caching. OpenAI Codex provider uses this to set `prompt_cache_key` and routing headers.
|
||||
|
||||
## [0.37.2] - 2026-01-05
|
||||
|
||||
### Fixed
|
||||
|
||||
- Codex provider now always includes `reasoning.encrypted_content` even when custom `include` options are passed ([#484](https://github.com/jae/agent-jae-cli/pull/484) by [@kim0](https://github.com/kim0))
|
||||
|
||||
## [0.37.1] - 2026-01-05
|
||||
|
||||
## [0.37.0] - 2026-01-05
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- OpenAI Codex models no longer have per-thinking-level variants (e.g., `gpt-5.2-codex-high`). Use the base model ID and set thinking level separately. The Codex provider clamps reasoning effort to what each model supports internally. (initial implementation by [@ben-vargas](https://github.com/ben-vargas) in [#472](https://github.com/jae/agent-jae-cli/pull/472))
|
||||
|
||||
### Added
|
||||
|
||||
- Headless OAuth support for all callback-server providers (Google Gemini CLI, Antigravity, OpenAI Codex): paste redirect URL when browser callback is unreachable ([#428](https://github.com/jae/agent-jae-cli/pull/428) by [@ben-vargas](https://github.com/ben-vargas), [#468](https://github.com/jae/agent-jae-cli/pull/468) by [@crcatala](https://github.com/crcatala))
|
||||
- Cancellable GitHub Copilot device code polling via AbortSignal
|
||||
|
||||
### Fixed
|
||||
|
||||
- Codex requests now omit the `reasoning` field entirely when thinking is off, letting the backend use its default instead of forcing a value. ([#472](https://github.com/jae/agent-jae-cli/pull/472))
|
||||
|
||||
## [0.36.0] - 2026-01-05
|
||||
|
||||
### Added
|
||||
|
||||
- OpenAI Codex OAuth provider with Responses API streaming support: `openai-codex-responses` streaming provider with SSE parsing, tool-call handling, usage/cost tracking, and PKCE OAuth flow ([#451](https://github.com/jae/agent-jae-cli/pull/451) by [@kim0](https://github.com/kim0))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Vertex AI dummy value for `getEnvApiKey()`: Returns `"<authenticated>"` when Application Default Credentials are configured (`~/.config/gcloud/application_default_credentials.json` exists) and both `GOOGLE_CLOUD_PROJECT` (or `GCLOUD_PROJECT`) and `GOOGLE_CLOUD_LOCATION` are set. This allows `streamSimple()` to work with Vertex AI without explicit `apiKey` option. The ADC credentials file existence check is cached per-process to avoid repeated filesystem access.
|
||||
|
||||
## [0.35.0] - 2026-01-05
|
||||
|
||||
## [0.34.2] - 2026-01-04
|
||||
|
||||
## [0.34.1] - 2026-01-04
|
||||
|
||||
## [0.34.0] - 2026-01-04
|
||||
|
||||
## [0.33.0] - 2026-01-04
|
||||
|
||||
## [0.32.3] - 2026-01-03
|
||||
|
||||
### Fixed
|
||||
|
||||
- Google Vertex AI models no longer appear in available models list without explicit authentication. Previously, `getEnvApiKey()` returned a dummy value for `google-vertex`, causing models to show up even when Google Cloud ADC was not configured.
|
||||
|
||||
## [0.32.2] - 2026-01-03
|
||||
|
||||
## [0.32.1] - 2026-01-03
|
||||
|
||||
## [0.32.0] - 2026-01-03
|
||||
|
||||
### Added
|
||||
|
||||
- Vertex AI provider with ADC (Application Default Credentials) support. Authenticate with `gcloud auth application-default login`, set `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`, and access Gemini models via Vertex AI. ([#300](https://github.com/jae/agent-jae-cli/pull/300) by [@default-anton](https://github.com/default-anton))
|
||||
|
||||
### Fixed
|
||||
|
||||
- **Gemini CLI rate limit handling**: Added automatic retry with server-provided delay for 429 errors. Parses delay from error messages like "Your quota will reset after 39s" and waits accordingly. Falls back to exponential backoff for other transient errors. ([#370](https://github.com/jae/agent-jae-cli/issues/370))
|
||||
|
||||
## [0.31.1] - 2026-01-02
|
||||
|
||||
## [0.31.0] - 2026-01-02
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Agent API moved**: All agent functionality (`agentLoop`, `agentLoopContinue`, `AgentContext`, `AgentEvent`, `AgentTool`, `AgentToolResult`, etc.) has moved to `@jaeswift/jae-agent-core`. Import from that package instead of `@jaeswift/jae-ai`.
|
||||
|
||||
### Added
|
||||
|
||||
- **`GoogleThinkingLevel` type**: Exported type that mirrors Google's `ThinkingLevel` enum values (`"THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH"`). Allows configuring Gemini thinking levels without importing from `@google/genai`.
|
||||
- **`ANTHROPIC_OAUTH_TOKEN` env var**: Now checked before `ANTHROPIC_API_KEY` in `getEnvApiKey()`, allowing OAuth tokens to take precedence.
|
||||
- **`event-stream.js` export**: `AssistantMessageEventStream` utility now exported from package index.
|
||||
|
||||
### Changed
|
||||
|
||||
- **OAuth uses Web Crypto API**: PKCE generation and OAuth flows now use Web Crypto API (`crypto.subtle`) instead of Node.js `crypto` module. This improves browser compatibility while still working in Node.js 20+.
|
||||
- **Deterministic model generation**: `generate-models.ts` now sorts providers and models alphabetically for consistent output across runs. ([#332](https://github.com/jae/agent-jae-cli/pull/332) by [@mrexodia](https://github.com/mrexodia))
|
||||
|
||||
### Fixed
|
||||
|
||||
- **OpenAI completions empty content blocks**: Empty text or thinking blocks in assistant messages are now filtered out before sending to the OpenAI completions API, preventing validation errors. ([#344](https://github.com/jae/agent-jae-cli/pull/344) by [@default-anton](https://github.com/default-anton))
|
||||
- **Thinking token duplication**: Fixed thinking content duplication with chutes.ai provider. The provider was returning thinking content in both `reasoning_content` and `reasoning` fields, causing each chunk to be processed twice. Now only the first non-empty reasoning field is used.
|
||||
- **zAi provider API mapping**: Fixed zAi models to use `openai-completions` API with correct base URL (`https://api.z.ai/api/coding/paas/v4`) instead of incorrect Anthropic API mapping. ([#344](https://github.com/jae/agent-jae-cli/pull/344), [#358](https://github.com/jae/agent-jae-cli/pull/358) by [@default-anton](https://github.com/default-anton))
|
||||
|
||||
## [0.28.0] - 2025-12-25
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **OAuth storage removed** ([#296](https://github.com/jae/agent-jae-cli/issues/296)): All storage functions (`loadOAuthCredentials`, `saveOAuthCredentials`, `setOAuthStorage`, etc.) removed. Callers are responsible for storing credentials.
|
||||
- **OAuth login functions**: `loginAnthropic`, `loginGitHubCopilot`, `loginGeminiCli`, `loginAntigravity` now return `OAuthCredentials` instead of saving to disk.
|
||||
- **refreshOAuthToken**: Now takes `(provider, credentials)` and returns new `OAuthCredentials` instead of saving.
|
||||
- **getOAuthApiKey**: Now takes `(provider, credentials)` and returns `{ newCredentials, apiKey }` or null.
|
||||
- **OAuthCredentials type**: No longer includes `type: "oauth"` discriminator. Callers add discriminator when storing.
|
||||
- **setApiKey, resolveApiKey**: Removed. Callers must manage their own API key storage/resolution.
|
||||
- **getApiKey**: Renamed to `getEnvApiKey`. Only checks environment variables for known providers.
|
||||
|
||||
## [0.27.7] - 2025-12-24
|
||||
|
||||
### Fixed
|
||||
|
||||
- **Thinking tag leakage**: Fixed Claude mimicking literal `</thinking>` tags in responses. Unsigned thinking blocks (from aborted streams) are now converted to plain text without `<thinking>` tags. The TUI still displays them as thinking blocks. ([#302](https://github.com/jae/agent-jae-cli/pull/302) by [@nicobailon](https://github.com/nicobailon))
|
||||
|
||||
## [0.25.1] - 2025-12-21
|
||||
|
||||
### Added
|
||||
|
||||
- **xhigh thinking level support**: Added `supportsXhigh()` function to check if a model supports xhigh reasoning level. Also clamps xhigh to high for OpenAI models that don't support it. ([#236](https://github.com/jae/agent-jae-cli/pull/236) by [@theBucky](https://github.com/theBucky))
|
||||
|
||||
### Fixed
|
||||
|
||||
- **Gemini multimodal tool results**: Fixed images in tool results causing flaky/broken responses with Gemini models. For Gemini 3, images are now nested inside `functionResponse.parts` per the [docs](https://ai.google.dev/gemini-api/docs/function-calling#multimodal). For older models (which don't support multimodal function responses), images are sent in a separate user message.
|
||||
|
||||
- **Queued message steering**: When `getQueuedMessages` is provided, the agent loop now checks for queued user messages after each tool call and skips remaining tool calls in the current assistant message when a queued message arrives (emitting error tool results).
|
||||
|
||||
- **Double API version path in Google provider URL**: Fixed Gemini API calls returning 404 after baseUrl support was added. The SDK was appending its default apiVersion to baseUrl which already included the version path. ([#251](https://github.com/jae/agent-jae-cli/pull/251) by [@shellfyred](https://github.com/shellfyred))
|
||||
|
||||
- **Anthropic SDK retries disabled**: Re-enabled SDK-level retries (default 2) for transient HTTP failures. ([#252](https://github.com/jae/agent-jae-cli/issues/252))
|
||||
|
||||
## [0.23.5] - 2025-12-19
|
||||
|
||||
### Added
|
||||
|
||||
- **Gemini 3 Flash thinking support**: Extended thinking level support for Gemini 3 Flash models (MINIMAL, LOW, MEDIUM, HIGH) to match Pro models' capabilities. ([#212](https://github.com/jae/agent-jae-cli/pull/212) by [@markusylisiurunen](https://github.com/markusylisiurunen))
|
||||
|
||||
- **GitHub Copilot thinking models**: Added thinking support for additional Copilot models (o3-mini, o1-mini, o1-preview). ([#234](https://github.com/jae/agent-jae-cli/pull/234) by [@aadishv](https://github.com/aadishv))
|
||||
|
||||
### Fixed
|
||||
|
||||
- **Gemini tool result format**: Fixed tool result format for Gemini 3 Flash Preview which strictly requires `{ output: value }` for success and `{ error: value }` for errors. Previous format using `{ result, isError }` was rejected by newer Gemini models. Also improved type safety by removing `as any` casts. ([#213](https://github.com/jae/agent-jae-cli/issues/213), [#220](https://github.com/jae/agent-jae-cli/pull/220))
|
||||
|
||||
- **Google baseUrl configuration**: Google provider now respects `baseUrl` configuration for custom endpoints or API proxies. ([#216](https://github.com/jae/agent-jae-cli/issues/216), [#221](https://github.com/jae/agent-jae-cli/pull/221) by [@theBucky](https://github.com/theBucky))
|
||||
|
||||
- **GitHub Copilot vision requests**: Added `Copilot-Vision-Request` header when sending images to GitHub Copilot models. ([#222](https://github.com/jae/agent-jae-cli/issues/222))
|
||||
|
||||
- **GitHub Copilot X-Initiator header**: Fixed X-Initiator logic to check last message role instead of any message in history. This ensures proper billing when users send follow-up messages. ([#209](https://github.com/jae/agent-jae-cli/issues/209))
|
||||
|
||||
## [0.22.3] - 2025-12-16
|
||||
|
||||
### Added
|
||||
|
||||
- **Image limits test suite**: Added comprehensive tests for provider-specific image limitations (max images, max size, max dimensions). Discovered actual limits: Anthropic (100 images, 5MB, 8000px), OpenAI (500 images, ≥25MB), Gemini (~2500 images, ≥40MB), Mistral (8 images, ~15MB), OpenRouter (~40 images context-limited, ~15MB). ([#120](https://github.com/jae/agent-jae-cli/pull/120))
|
||||
|
||||
- **Tool result streaming**: Added `tool_execution_update` event and optional `onUpdate` callback to `AgentTool.execute()` for streaming tool output during execution. Tools can now emit partial results (e.g., bash stdout) that are forwarded to subscribers. ([#44](https://github.com/jae/agent-jae-cli/issues/44))
|
||||
|
||||
- **X-Initiator header for GitHub Copilot**: Added X-Initiator header handling for GitHub Copilot provider to ensure correct call accounting (agent calls are not deducted from quota). Sets initiator based on last message role. ([#200](https://github.com/jae/agent-jae-cli/pull/200) by [@kim0](https://github.com/kim0))
|
||||
|
||||
### Changed
|
||||
|
||||
- **Normalized tool_execution_end result**: `tool_execution_end` event now always contains `AgentToolResult` (no longer `AgentToolResult | string`). Errors are wrapped in the standard result format.
|
||||
|
||||
### Fixed
|
||||
|
||||
- **Reasoning disabled by default**: When `reasoning` option is not specified, thinking is now explicitly disabled for all providers. Previously, some providers like Gemini with "dynamic thinking" would use their default (thinking ON), causing unexpected token usage. This was the original intended behavior. ([#180](https://github.com/jae/agent-jae-cli/pull/180) by [@markusylisiurunen](https://github.com/markusylisiurunen))
|
||||
|
||||
## [0.22.2] - 2025-12-15
|
||||
|
||||
### Added
|
||||
|
||||
- **Interleaved thinking for Anthropic**: Added `interleavedThinking` option to `AnthropicOptions`. When enabled, Claude 4 models can think between tool calls and reason after receiving tool results. Enabled by default (no extra token cost, just unlocks the capability). Set `interleavedThinking: false` to disable.
|
||||
|
||||
## [0.22.1] - 2025-12-15
|
||||
|
||||
_Dedicated to Peter's shoulder ([@steipete](https://twitter.com/steipete))_
|
||||
|
||||
### Added
|
||||
|
||||
- **Interleaved thinking for Anthropic**: Enabled interleaved thinking in the Anthropic provider, allowing Claude models to output thinking blocks interspersed with text responses.
|
||||
|
||||
## [0.22.0] - 2025-12-15
|
||||
|
||||
### Added
|
||||
|
||||
- **GitHub Copilot provider**: Added `github-copilot` as a known provider with models sourced from models.dev. Includes Claude, GPT, Gemini, Grok, and other models available through GitHub Copilot. ([#191](https://github.com/jae/agent-jae-cli/pull/191) by [@cau1k](https://github.com/cau1k))
|
||||
|
||||
### Fixed
|
||||
|
||||
- **GitHub Copilot gpt-5 models**: Fixed API selection for gpt-5 models to use `openai-responses` instead of `openai-completions` (gpt-5 models are not accessible via completions endpoint)
|
||||
|
||||
- **GitHub Copilot cross-model context handoff**: Fixed context handoff failing when switching between GitHub Copilot models using different APIs (e.g., gpt-5 to claude-sonnet-4). Tool call IDs from OpenAI Responses API were incompatible with other models. ([#198](https://github.com/jae/agent-jae-cli/issues/198))
|
||||
|
||||
- **Gemini 3 Pro thinking levels**: Thinking level configuration now works correctly for Gemini 3 Pro models. Previously all levels mapped to -1 (minimal thinking). Now LOW/MEDIUM/HIGH properly control test-time computation. ([#176](https://github.com/jae/agent-jae-cli/pull/176) by [@markusylisiurunen](https://github.com/markusylisiurunen))
|
||||
|
||||
## [0.18.2] - 2025-12-11
|
||||
|
||||
### Changed
|
||||
|
||||
- **Anthropic SDK retries disabled**: Set `maxRetries: 0` on Anthropic client to allow application-level retry handling. The SDK's built-in retries were interfering with coding-agent's retry logic. ([#157](https://github.com/jae/agent-jae-cli/issues/157))
|
||||
|
||||
## [0.18.1] - 2025-12-10
|
||||
|
||||
### Added
|
||||
|
||||
- **Mistral provider**: Added support for Mistral AI models via the OpenAI-compatible API. Includes automatic handling of Mistral-specific requirements (tool call ID format). Set `MISTRAL_API_KEY` environment variable to use.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Mistral 400 errors after aborted assistant messages by skipping empty assistant messages (no content, no tool calls) ([#165](https://github.com/jae/agent-jae-cli/issues/165))
|
||||
|
||||
- Removed synthetic assistant bridge message after tool results for Mistral (no longer required as of Dec 2025) ([#165](https://github.com/jae/agent-jae-cli/issues/165))
|
||||
|
||||
- Fixed bug where `ANTHROPIC_API_KEY` environment variable was deleted globally after first OAuth token usage, causing subsequent prompts to fail ([#164](https://github.com/jae/agent-jae-cli/pull/164))
|
||||
|
||||
## [0.17.0] - 2025-12-09
|
||||
|
||||
### Added
|
||||
|
||||
- **`agentLoopContinue` function**: Continue an agent loop from existing context without adding a new user message. Validates that the last message is `user` or `toolResult`. Useful for retry after context overflow or resuming from manually-added tool results.
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- Removed provider-level tool argument validation. Validation now happens in `agentLoop` via `executeToolCalls`, allowing models to retry on validation errors. For manual tool execution, use `validateToolCall(tools, toolCall)` or `validateToolArguments(tool, toolCall)`.
|
||||
|
||||
### Added
|
||||
|
||||
- Added `validateToolCall(tools, toolCall)` helper that finds the tool by name and validates arguments.
|
||||
|
||||
- **OpenAI compatibility overrides**: Added `compat` field to `Model` for `openai-completions` API, allowing explicit configuration of provider quirks (`supportsStore`, `supportsDeveloperRole`, `supportsReasoningEffort`, `maxTokensField`). Falls back to URL-based detection if not set. Useful for LiteLLM, custom proxies, and other non-standard endpoints. ([#133](https://github.com/jae/agent-jae-cli/issues/133), thanks @fink-andreas for the initial idea and PR)
|
||||
|
||||
- **xhigh reasoning level**: Added `xhigh` to `ReasoningEffort` type for OpenAI codex-max models. For non-OpenAI providers (Anthropic, Google), `xhigh` is automatically mapped to `high`. ([#143](https://github.com/jae/agent-jae-cli/issues/143))
|
||||
|
||||
### Changed
|
||||
|
||||
- **Updated SDK versions**: OpenAI SDK 5.21.0 → 6.10.0, Anthropic SDK 0.61.0 → 0.71.2, Google GenAI SDK 1.30.0 → 1.31.0
|
||||
|
||||
## [0.13.0] - 2025-12-06
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **Added `totalTokens` field to `Usage` type**: All code that constructs `Usage` objects must now include the `totalTokens` field. This field represents the total tokens processed by the LLM (input + output + cache). For OpenAI and Google, this uses native API values (`total_tokens`, `totalTokenCount`). For Anthropic, it's computed as `input + output + cacheRead + cacheWrite`.
|
||||
|
||||
## [0.12.10] - 2025-12-04
|
||||
|
||||
### Added
|
||||
|
||||
- Added `gpt-5.1-codex-max` model support
|
||||
|
||||
### Fixed
|
||||
|
||||
- **OpenAI Token Counting**: Fixed `usage.input` to exclude cached tokens for OpenAI providers. Previously, `input` included cached tokens, causing double-counting when calculating total context size via `input + cacheRead`. Now `input` represents non-cached input tokens across all providers, making `input + output + cacheRead + cacheWrite` the correct formula for total context size.
|
||||
|
||||
- **Fixed Claude Opus 4.5 cache pricing** (was 3x too expensive)
|
||||
- Corrected cache_read: $1.50 → $0.50 per MTok
|
||||
- Corrected cache_write: $18.75 → $6.25 per MTok
|
||||
- Added manual override in `scripts/generate-models.ts` until upstream fix is merged
|
||||
- Submitted PR to models.dev: https://github.com/sst/models.dev/pull/439
|
||||
|
||||
## [0.9.4] - 2025-11-26
|
||||
|
||||
Initial release with multi-provider LLM support.
|
||||
1229
packages/ai/README.md
Normal file
1229
packages/ai/README.md
Normal file
File diff suppressed because it is too large
Load diff
1
packages/ai/bedrock-provider.d.ts
vendored
Normal file
1
packages/ai/bedrock-provider.d.ts
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from "./dist/bedrock-provider.js";
|
||||
1
packages/ai/bedrock-provider.js
Normal file
1
packages/ai/bedrock-provider.js
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from "./dist/bedrock-provider.js";
|
||||
114
packages/ai/package.json
Normal file
114
packages/ai/package.json
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
{
|
||||
"name": "@jaeswift/jae-ai",
|
||||
"version": "0.62.0",
|
||||
"description": "Unified LLM API with automatic model discovery and provider configuration",
|
||||
"type": "module",
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
},
|
||||
"./anthropic": {
|
||||
"types": "./dist/providers/anthropic.d.ts",
|
||||
"import": "./dist/providers/anthropic.js"
|
||||
},
|
||||
"./azure-openai-responses": {
|
||||
"types": "./dist/providers/azure-openai-responses.d.ts",
|
||||
"import": "./dist/providers/azure-openai-responses.js"
|
||||
},
|
||||
"./google": {
|
||||
"types": "./dist/providers/google.d.ts",
|
||||
"import": "./dist/providers/google.js"
|
||||
},
|
||||
"./google-gemini-cli": {
|
||||
"types": "./dist/providers/google-gemini-cli.d.ts",
|
||||
"import": "./dist/providers/google-gemini-cli.js"
|
||||
},
|
||||
"./google-vertex": {
|
||||
"types": "./dist/providers/google-vertex.d.ts",
|
||||
"import": "./dist/providers/google-vertex.js"
|
||||
},
|
||||
"./mistral": {
|
||||
"types": "./dist/providers/mistral.d.ts",
|
||||
"import": "./dist/providers/mistral.js"
|
||||
},
|
||||
"./openai-codex-responses": {
|
||||
"types": "./dist/providers/openai-codex-responses.d.ts",
|
||||
"import": "./dist/providers/openai-codex-responses.js"
|
||||
},
|
||||
"./openai-completions": {
|
||||
"types": "./dist/providers/openai-completions.d.ts",
|
||||
"import": "./dist/providers/openai-completions.js"
|
||||
},
|
||||
"./openai-responses": {
|
||||
"types": "./dist/providers/openai-responses.d.ts",
|
||||
"import": "./dist/providers/openai-responses.js"
|
||||
},
|
||||
"./oauth": {
|
||||
"types": "./dist/oauth.d.ts",
|
||||
"import": "./dist/oauth.js"
|
||||
},
|
||||
"./bedrock-provider": {
|
||||
"types": "./dist/bedrock-provider.d.ts",
|
||||
"import": "./dist/bedrock-provider.js"
|
||||
}
|
||||
},
|
||||
"bin": {
|
||||
"jae-ai": "./dist/cli.js"
|
||||
},
|
||||
"files": [
|
||||
"dist",
|
||||
"README.md"
|
||||
],
|
||||
"scripts": {
|
||||
"clean": "shx rm -rf dist",
|
||||
"generate-models": "npx tsx scripts/generate-models.ts",
|
||||
"build": "npm run generate-models && tsgo -p tsconfig.build.json",
|
||||
"dev": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"dev:tsc": "tsgo -p tsconfig.build.json --watch --preserveWatchOutput",
|
||||
"test": "vitest --run",
|
||||
"prepublishOnly": "npm run clean && npm run build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.73.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.983.0",
|
||||
"@google/genai": "^1.40.0",
|
||||
"@mistralai/mistralai": "1.14.1",
|
||||
"@sinclair/typebox": "^0.34.41",
|
||||
"ajv": "^8.17.1",
|
||||
"ajv-formats": "^3.0.1",
|
||||
"chalk": "^5.6.2",
|
||||
"openai": "6.26.0",
|
||||
"partial-json": "^0.1.7",
|
||||
"proxy-agent": "^6.5.0",
|
||||
"undici": "^7.19.1",
|
||||
"zod-to-json-schema": "^3.24.6"
|
||||
},
|
||||
"keywords": [
|
||||
"ai",
|
||||
"llm",
|
||||
"openai",
|
||||
"anthropic",
|
||||
"gemini",
|
||||
"bedrock",
|
||||
"unified",
|
||||
"api"
|
||||
],
|
||||
"author": "Mario Zechner",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/jae/agent-jae-cli.git",
|
||||
"directory": "packages/ai"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^24.3.0",
|
||||
"canvas": "^3.2.0",
|
||||
"vitest": "^3.2.4"
|
||||
}
|
||||
}
|
||||
1593
packages/ai/scripts/generate-models.ts
Normal file
1593
packages/ai/scripts/generate-models.ts
Normal file
File diff suppressed because it is too large
Load diff
34
packages/ai/scripts/generate-test-image.ts
Normal file
34
packages/ai/scripts/generate-test-image.ts
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
#!/usr/bin/env tsx
|
||||
|
||||
import { createCanvas } from "canvas";
|
||||
import { writeFileSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
// Create a 200x200 canvas
|
||||
const canvas = createCanvas(200, 200);
|
||||
const ctx = canvas.getContext("2d");
|
||||
|
||||
// Fill background with white
|
||||
ctx.fillStyle = "white";
|
||||
ctx.fillRect(0, 0, 200, 200);
|
||||
|
||||
// Draw a red circle in the center
|
||||
ctx.fillStyle = "red";
|
||||
ctx.beginPath();
|
||||
ctx.arc(100, 100, 50, 0, Math.PI * 2);
|
||||
ctx.fill();
|
||||
|
||||
// Save the image
|
||||
const buffer = canvas.toBuffer("image/png");
|
||||
const outputPath = join(__dirname, "..", "test", "data", "red-circle.png");
|
||||
|
||||
// Ensure the directory exists
|
||||
import { mkdirSync } from "fs";
|
||||
mkdirSync(join(__dirname, "..", "test", "data"), { recursive: true });
|
||||
|
||||
writeFileSync(outputPath, buffer);
|
||||
console.log(`Generated test image at: ${outputPath}`);
|
||||
98
packages/ai/src/api-registry.ts
Normal file
98
packages/ai/src/api-registry.ts
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export type ApiStreamFunction = (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: StreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export type ApiStreamSimpleFunction = (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export interface ApiProvider<TApi extends Api = Api, TOptions extends StreamOptions = StreamOptions> {
|
||||
api: TApi;
|
||||
stream: StreamFunction<TApi, TOptions>;
|
||||
streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface ApiProviderInternal {
|
||||
api: Api;
|
||||
stream: ApiStreamFunction;
|
||||
streamSimple: ApiStreamSimpleFunction;
|
||||
}
|
||||
|
||||
type RegisteredApiProvider = {
|
||||
provider: ApiProviderInternal;
|
||||
sourceId?: string;
|
||||
};
|
||||
|
||||
const apiProviderRegistry = new Map<string, RegisteredApiProvider>();
|
||||
|
||||
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
|
||||
api: TApi,
|
||||
stream: StreamFunction<TApi, TOptions>,
|
||||
): ApiStreamFunction {
|
||||
return (model, context, options) => {
|
||||
if (model.api !== api) {
|
||||
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
|
||||
}
|
||||
return stream(model as Model<TApi>, context, options as TOptions);
|
||||
};
|
||||
}
|
||||
|
||||
function wrapStreamSimple<TApi extends Api>(
|
||||
api: TApi,
|
||||
streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
|
||||
): ApiStreamSimpleFunction {
|
||||
return (model, context, options) => {
|
||||
if (model.api !== api) {
|
||||
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
|
||||
}
|
||||
return streamSimple(model as Model<TApi>, context, options);
|
||||
};
|
||||
}
|
||||
|
||||
export function registerApiProvider<TApi extends Api, TOptions extends StreamOptions>(
|
||||
provider: ApiProvider<TApi, TOptions>,
|
||||
sourceId?: string,
|
||||
): void {
|
||||
apiProviderRegistry.set(provider.api, {
|
||||
provider: {
|
||||
api: provider.api,
|
||||
stream: wrapStream(provider.api, provider.stream),
|
||||
streamSimple: wrapStreamSimple(provider.api, provider.streamSimple),
|
||||
},
|
||||
sourceId,
|
||||
});
|
||||
}
|
||||
|
||||
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
|
||||
return apiProviderRegistry.get(api)?.provider;
|
||||
}
|
||||
|
||||
export function getApiProviders(): ApiProviderInternal[] {
|
||||
return Array.from(apiProviderRegistry.values(), (entry) => entry.provider);
|
||||
}
|
||||
|
||||
export function unregisterApiProviders(sourceId: string): void {
|
||||
for (const [api, entry] of apiProviderRegistry.entries()) {
|
||||
if (entry.sourceId === sourceId) {
|
||||
apiProviderRegistry.delete(api);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function clearApiProviders(): void {
|
||||
apiProviderRegistry.clear();
|
||||
}
|
||||
6
packages/ai/src/bedrock-provider.ts
Normal file
6
packages/ai/src/bedrock-provider.ts
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
import { streamBedrock, streamSimpleBedrock } from "./providers/amazon-bedrock.js";
|
||||
|
||||
export const bedrockProviderModule = {
|
||||
streamBedrock,
|
||||
streamSimpleBedrock,
|
||||
};
|
||||
133
packages/ai/src/cli.ts
Normal file
133
packages/ai/src/cli.ts
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
#!/usr/bin/env node
|
||||
|
||||
import { existsSync, readFileSync, writeFileSync } from "fs";
|
||||
import { createInterface } from "readline";
|
||||
import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
|
||||
import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
|
||||
|
||||
const AUTH_FILE = "auth.json";
|
||||
const PROVIDERS = getOAuthProviders();
|
||||
|
||||
function prompt(rl: ReturnType<typeof createInterface>, question: string): Promise<string> {
|
||||
return new Promise((resolve) => rl.question(question, resolve));
|
||||
}
|
||||
|
||||
function loadAuth(): Record<string, { type: "oauth" } & OAuthCredentials> {
|
||||
if (!existsSync(AUTH_FILE)) return {};
|
||||
try {
|
||||
return JSON.parse(readFileSync(AUTH_FILE, "utf-8"));
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function saveAuth(auth: Record<string, { type: "oauth" } & OAuthCredentials>): void {
|
||||
writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
|
||||
}
|
||||
|
||||
async function login(providerId: OAuthProviderId): Promise<void> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
console.error(`Unknown provider: ${providerId}`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
||||
const promptFn = (msg: string) => prompt(rl, `${msg} `);
|
||||
|
||||
try {
|
||||
const credentials = await provider.login({
|
||||
onAuth: (info) => {
|
||||
console.log(`\nOpen this URL in your browser:\n${info.url}`);
|
||||
if (info.instructions) console.log(info.instructions);
|
||||
console.log();
|
||||
},
|
||||
onPrompt: async (p) => {
|
||||
return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`);
|
||||
},
|
||||
onProgress: (msg) => console.log(msg),
|
||||
});
|
||||
|
||||
const auth = loadAuth();
|
||||
auth[providerId] = { type: "oauth", ...credentials };
|
||||
saveAuth(auth);
|
||||
|
||||
console.log(`\nCredentials saved to ${AUTH_FILE}`);
|
||||
} finally {
|
||||
rl.close();
|
||||
}
|
||||
}
|
||||
|
||||
async function main(): Promise<void> {
|
||||
const args = process.argv.slice(2);
|
||||
const command = args[0];
|
||||
|
||||
if (!command || command === "help" || command === "--help" || command === "-h") {
|
||||
const providerList = PROVIDERS.map((p) => ` ${p.id.padEnd(20)} ${p.name}`).join("\n");
|
||||
console.log(`Usage: npx @jaeswift/jae-ai <command> [provider]
|
||||
|
||||
Commands:
|
||||
login [provider] Login to an OAuth provider
|
||||
list List available providers
|
||||
|
||||
Providers:
|
||||
${providerList}
|
||||
|
||||
Examples:
|
||||
npx @jaeswift/jae-ai login # interactive provider selection
|
||||
npx @jaeswift/jae-ai login anthropic # login to specific provider
|
||||
npx @jaeswift/jae-ai list # list providers
|
||||
`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === "list") {
|
||||
console.log("Available OAuth providers:\n");
|
||||
for (const p of PROVIDERS) {
|
||||
console.log(` ${p.id.padEnd(20)} ${p.name}`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === "login") {
|
||||
let provider = args[1] as OAuthProviderId | undefined;
|
||||
|
||||
if (!provider) {
|
||||
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
||||
console.log("Select a provider:\n");
|
||||
for (let i = 0; i < PROVIDERS.length; i++) {
|
||||
console.log(` ${i + 1}. ${PROVIDERS[i].name}`);
|
||||
}
|
||||
console.log();
|
||||
|
||||
const choice = await prompt(rl, `Enter number (1-${PROVIDERS.length}): `);
|
||||
rl.close();
|
||||
|
||||
const index = parseInt(choice, 10) - 1;
|
||||
if (index < 0 || index >= PROVIDERS.length) {
|
||||
console.error("Invalid selection");
|
||||
process.exit(1);
|
||||
}
|
||||
provider = PROVIDERS[index].id;
|
||||
}
|
||||
|
||||
if (!PROVIDERS.some((p) => p.id === provider)) {
|
||||
console.error(`Unknown provider: ${provider}`);
|
||||
console.error(`Use 'npx @jaeswift/jae-ai list' to see available providers`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
console.log(`Logging in to ${provider}...`);
|
||||
await login(provider);
|
||||
return;
|
||||
}
|
||||
|
||||
console.error(`Unknown command: ${command}`);
|
||||
console.error(`Use 'npx @jaeswift/jae-ai --help' for usage`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Error:", err.message);
|
||||
process.exit(1);
|
||||
});
|
||||
133
packages/ai/src/env-api-keys.ts
Normal file
133
packages/ai/src/env-api-keys.ts
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
|
||||
let _existsSync: typeof import("node:fs").existsSync | null = null;
|
||||
let _homedir: typeof import("node:os").homedir | null = null;
|
||||
let _join: typeof import("node:path").join | null = null;
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;
|
||||
|
||||
const dynamicImport: DynamicImport = (specifier) => import(specifier);
|
||||
const NODE_FS_SPECIFIER = "node:" + "fs";
|
||||
const NODE_OS_SPECIFIER = "node:" + "os";
|
||||
const NODE_PATH_SPECIFIER = "node:" + "path";
|
||||
|
||||
// Eagerly load in Node.js/Bun environment only
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
dynamicImport(NODE_FS_SPECIFIER).then((m) => {
|
||||
_existsSync = (m as typeof import("node:fs")).existsSync;
|
||||
});
|
||||
dynamicImport(NODE_OS_SPECIFIER).then((m) => {
|
||||
_homedir = (m as typeof import("node:os")).homedir;
|
||||
});
|
||||
dynamicImport(NODE_PATH_SPECIFIER).then((m) => {
|
||||
_join = (m as typeof import("node:path")).join;
|
||||
});
|
||||
}
|
||||
|
||||
import type { KnownProvider } from "./types.js";
|
||||
|
||||
let cachedVertexAdcCredentialsExists: boolean | null = null;
|
||||
|
||||
function hasVertexAdcCredentials(): boolean {
|
||||
if (cachedVertexAdcCredentialsExists === null) {
|
||||
// If node modules haven't loaded yet (async import race at startup),
|
||||
// return false WITHOUT caching so the next call retries once they're ready.
|
||||
// Only cache false permanently in a browser environment where fs is never available.
|
||||
if (!_existsSync || !_homedir || !_join) {
|
||||
const isNode = typeof process !== "undefined" && (process.versions?.node || process.versions?.bun);
|
||||
if (!isNode) {
|
||||
// Definitively in a browser — safe to cache false permanently
|
||||
cachedVertexAdcCredentialsExists = false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
|
||||
const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS;
|
||||
if (gacPath) {
|
||||
cachedVertexAdcCredentialsExists = _existsSync(gacPath);
|
||||
} else {
|
||||
// Fall back to default ADC path (lazy evaluation)
|
||||
cachedVertexAdcCredentialsExists = _existsSync(
|
||||
_join(_homedir(), ".config", "gcloud", "application_default_credentials.json"),
|
||||
);
|
||||
}
|
||||
}
|
||||
return cachedVertexAdcCredentialsExists;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
|
||||
*
|
||||
* Will not return API keys for providers that require OAuth tokens.
|
||||
*/
|
||||
export function getEnvApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getEnvApiKey(provider: string): string | undefined;
|
||||
export function getEnvApiKey(provider: any): string | undefined {
|
||||
// Fall back to environment variables
|
||||
if (provider === "github-copilot") {
|
||||
return process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
|
||||
}
|
||||
|
||||
// ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
|
||||
if (provider === "anthropic") {
|
||||
return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
|
||||
// Vertex AI supports either an explicit API key or Application Default Credentials
|
||||
// Auth is configured via `gcloud auth application-default login`
|
||||
if (provider === "google-vertex") {
|
||||
if (process.env.GOOGLE_CLOUD_API_KEY) {
|
||||
return process.env.GOOGLE_CLOUD_API_KEY;
|
||||
}
|
||||
|
||||
const hasCredentials = hasVertexAdcCredentials();
|
||||
const hasProject = !!(process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT);
|
||||
const hasLocation = !!process.env.GOOGLE_CLOUD_LOCATION;
|
||||
|
||||
if (hasCredentials && hasProject && hasLocation) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
if (provider === "amazon-bedrock") {
|
||||
// Amazon Bedrock supports multiple credential sources:
|
||||
// 1. AWS_PROFILE - named profile from ~/.aws/credentials
|
||||
// 2. AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY - standard IAM keys
|
||||
// 3. AWS_BEARER_TOKEN_BEDROCK - Bedrock API keys (bearer token)
|
||||
// 4. AWS_CONTAINER_CREDENTIALS_RELATIVE_URI - ECS task roles
|
||||
// 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
|
||||
// 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
|
||||
if (
|
||||
process.env.AWS_PROFILE ||
|
||||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
|
||||
process.env.AWS_WEB_IDENTITY_TOKEN_FILE
|
||||
) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
const envMap: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
"azure-openai-responses": "AZURE_OPENAI_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
"vercel-ai-gateway": "AI_GATEWAY_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
mistral: "MISTRAL_API_KEY",
|
||||
minimax: "MINIMAX_API_KEY",
|
||||
"minimax-cn": "MINIMAX_CN_API_KEY",
|
||||
huggingface: "HF_TOKEN",
|
||||
opencode: "OPENCODE_API_KEY",
|
||||
"opencode-go": "OPENCODE_API_KEY",
|
||||
"kimi-coding": "KIMI_API_KEY",
|
||||
};
|
||||
|
||||
const envVar = envMap[provider];
|
||||
return envVar ? process.env[envVar] : undefined;
|
||||
}
|
||||
34
packages/ai/src/index.ts
Normal file
34
packages/ai/src/index.ts
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
export type { Static, TSchema } from "@sinclair/typebox";
|
||||
export { Type } from "@sinclair/typebox";
|
||||
|
||||
export * from "./api-registry.js";
|
||||
export * from "./env-api-keys.js";
|
||||
export * from "./models.js";
|
||||
export type { BedrockOptions } from "./providers/amazon-bedrock.js";
|
||||
export type { AnthropicOptions } from "./providers/anthropic.js";
|
||||
export type { AzureOpenAIResponsesOptions } from "./providers/azure-openai-responses.js";
|
||||
export type { GoogleOptions } from "./providers/google.js";
|
||||
export type { GoogleGeminiCliOptions, GoogleThinkingLevel } from "./providers/google-gemini-cli.js";
|
||||
export type { GoogleVertexOptions } from "./providers/google-vertex.js";
|
||||
export type { MistralOptions } from "./providers/mistral.js";
|
||||
export type { OpenAICodexResponsesOptions } from "./providers/openai-codex-responses.js";
|
||||
export type { OpenAICompletionsOptions } from "./providers/openai-completions.js";
|
||||
export type { OpenAIResponsesOptions } from "./providers/openai-responses.js";
|
||||
export * from "./providers/register-builtins.js";
|
||||
export * from "./stream.js";
|
||||
export * from "./types.js";
|
||||
export * from "./utils/event-stream.js";
|
||||
export * from "./utils/json-parse.js";
|
||||
export type {
|
||||
OAuthAuthInfo,
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthPrompt,
|
||||
OAuthProvider,
|
||||
OAuthProviderId,
|
||||
OAuthProviderInfo,
|
||||
OAuthProviderInterface,
|
||||
} from "./utils/oauth/types.js";
|
||||
export * from "./utils/overflow.js";
|
||||
export * from "./utils/typebox-helpers.js";
|
||||
export * from "./utils/validation.js";
|
||||
13899
packages/ai/src/models.generated.ts
Normal file
13899
packages/ai/src/models.generated.ts
Normal file
File diff suppressed because it is too large
Load diff
77
packages/ai/src/models.ts
Normal file
77
packages/ai/src/models.ts
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import { MODELS } from "./models.generated.js";
|
||||
import type { Api, KnownProvider, Model, Usage } from "./types.js";
|
||||
|
||||
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
|
||||
|
||||
// Initialize registry from MODELS on module load
|
||||
for (const [provider, models] of Object.entries(MODELS)) {
|
||||
const providerModels = new Map<string, Model<Api>>();
|
||||
for (const [id, model] of Object.entries(models)) {
|
||||
providerModels.set(id, model as Model<Api>);
|
||||
}
|
||||
modelRegistry.set(provider, providerModels);
|
||||
}
|
||||
|
||||
type ModelApi<
|
||||
TProvider extends KnownProvider,
|
||||
TModelId extends keyof (typeof MODELS)[TProvider],
|
||||
> = (typeof MODELS)[TProvider][TModelId] extends { api: infer TApi } ? (TApi extends Api ? TApi : never) : never;
|
||||
|
||||
export function getModel<TProvider extends KnownProvider, TModelId extends keyof (typeof MODELS)[TProvider]>(
|
||||
provider: TProvider,
|
||||
modelId: TModelId,
|
||||
): Model<ModelApi<TProvider, TModelId>> {
|
||||
const providerModels = modelRegistry.get(provider);
|
||||
return providerModels?.get(modelId as string) as Model<ModelApi<TProvider, TModelId>>;
|
||||
}
|
||||
|
||||
export function getProviders(): KnownProvider[] {
|
||||
return Array.from(modelRegistry.keys()) as KnownProvider[];
|
||||
}
|
||||
|
||||
export function getModels<TProvider extends KnownProvider>(
|
||||
provider: TProvider,
|
||||
): Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[] {
|
||||
const models = modelRegistry.get(provider);
|
||||
return models ? (Array.from(models.values()) as Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[]) : [];
|
||||
}
|
||||
|
||||
export function calculateCost<TApi extends Api>(model: Model<TApi>, usage: Usage): Usage["cost"] {
|
||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||
usage.cost.cacheWrite = (model.cost.cacheWrite / 1000000) * usage.cacheWrite;
|
||||
usage.cost.total = usage.cost.input + usage.cost.output + usage.cost.cacheRead + usage.cost.cacheWrite;
|
||||
return usage.cost;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports xhigh thinking level.
|
||||
*
|
||||
* Supported today:
|
||||
* - GPT-5.2 / GPT-5.3 / GPT-5.4 model families
|
||||
* - Opus 4.6 models (xhigh maps to adaptive effort "max" on Anthropic-compatible providers)
|
||||
*/
|
||||
export function supportsXhigh<TApi extends Api>(model: Model<TApi>): boolean {
|
||||
if (model.id.includes("gpt-5.2") || model.id.includes("gpt-5.3") || model.id.includes("gpt-5.4")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (model.id.includes("opus-4-6") || model.id.includes("opus-4.6")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if two models are equal by comparing both their id and provider.
|
||||
* Returns false if either model is null or undefined.
|
||||
*/
|
||||
export function modelsAreEqual<TApi extends Api>(
|
||||
a: Model<TApi> | null | undefined,
|
||||
b: Model<TApi> | null | undefined,
|
||||
): boolean {
|
||||
if (!a || !b) return false;
|
||||
return a.id === b.id && a.provider === b.provider;
|
||||
}
|
||||
1
packages/ai/src/oauth.ts
Normal file
1
packages/ai/src/oauth.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from "./utils/oauth/index.js";
|
||||
776
packages/ai/src/providers/amazon-bedrock.ts
Normal file
776
packages/ai/src/providers/amazon-bedrock.ts
Normal file
|
|
@ -0,0 +1,776 @@
|
|||
import {
|
||||
BedrockRuntimeClient,
|
||||
type BedrockRuntimeClientConfig,
|
||||
StopReason as BedrockStopReason,
|
||||
type Tool as BedrockTool,
|
||||
CachePointType,
|
||||
CacheTTL,
|
||||
type ContentBlock,
|
||||
type ContentBlockDeltaEvent,
|
||||
type ContentBlockStartEvent,
|
||||
type ContentBlockStopEvent,
|
||||
ConversationRole,
|
||||
ConverseStreamCommand,
|
||||
type ConverseStreamMetadataEvent,
|
||||
ImageFormat,
|
||||
type Message,
|
||||
type SystemContentBlock,
|
||||
type ToolChoice,
|
||||
type ToolConfiguration,
|
||||
ToolResultStatus,
|
||||
} from "@aws-sdk/client-bedrock-runtime";
|
||||
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { adjustMaxTokensForThinking, buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
export interface BedrockOptions extends StreamOptions {
|
||||
region?: string;
|
||||
profile?: string;
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
/* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-reasoning.html for supported models. */
|
||||
reasoning?: ThinkingLevel;
|
||||
/* Custom token budgets per thinking level. Overrides default budgets. */
|
||||
thinkingBudgets?: ThinkingBudgets;
|
||||
/* Only supported by Claude 4.x models, see https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html#claude-messages-extended-thinking-tool-use-interleaved */
|
||||
interleavedThinking?: boolean;
|
||||
/** Key-value pairs attached to the inference request for cost allocation tagging.
|
||||
* Keys: max 64 chars, no `aws:` prefix. Values: max 256 chars. Max 50 pairs.
|
||||
* Tags appear in AWS Cost Explorer split cost allocation data.
|
||||
* @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html */
|
||||
requestMetadata?: Record<string, string>;
|
||||
}
|
||||
|
||||
type Block = (TextContent | ThinkingContent | ToolCall) & { index?: number; partialJson?: string };
|
||||
|
||||
export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOptions> = (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options: BedrockOptions = {},
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "bedrock-converse-stream" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const blocks = output.content as Block[];
|
||||
|
||||
const config: BedrockRuntimeClientConfig = {
|
||||
profile: options.profile,
|
||||
};
|
||||
|
||||
// in Node.js/Bun environment only
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
// Region resolution: explicit option > env vars > SDK default chain.
|
||||
// When AWS_PROFILE is set, we leave region undefined so the SDK can
|
||||
// resovle it from aws profile configs. Otherwise fall back to us-east-1.
|
||||
const explicitRegion = options.region || process.env.AWS_REGION || process.env.AWS_DEFAULT_REGION;
|
||||
if (explicitRegion) {
|
||||
config.region = explicitRegion;
|
||||
} else if (!process.env.AWS_PROFILE) {
|
||||
config.region = "us-east-1";
|
||||
}
|
||||
|
||||
// Support proxies that don't need authentication
|
||||
if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
|
||||
config.credentials = {
|
||||
accessKeyId: "dummy-access-key",
|
||||
secretAccessKey: "dummy-secret-key",
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
process.env.HTTP_PROXY ||
|
||||
process.env.HTTPS_PROXY ||
|
||||
process.env.NO_PROXY ||
|
||||
process.env.http_proxy ||
|
||||
process.env.https_proxy ||
|
||||
process.env.no_proxy
|
||||
) {
|
||||
const nodeHttpHandler = await import("@smithy/node-http-handler");
|
||||
const proxyAgent = await import("proxy-agent");
|
||||
|
||||
const agent = new proxyAgent.ProxyAgent();
|
||||
|
||||
// Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
|
||||
// on `http2` module and has no support for http agent.
|
||||
// Use NodeHttpHandler to support http agent.
|
||||
config.requestHandler = new nodeHttpHandler.NodeHttpHandler({
|
||||
httpAgent: agent,
|
||||
httpsAgent: agent,
|
||||
});
|
||||
} else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
|
||||
// Some custom endpoints require HTTP/1.1 instead of HTTP/2
|
||||
const nodeHttpHandler = await import("@smithy/node-http-handler");
|
||||
config.requestHandler = new nodeHttpHandler.NodeHttpHandler();
|
||||
}
|
||||
} else {
|
||||
// Non-Node environment (browser): fall back to us-east-1 since
|
||||
// there's no config file resolution available.
|
||||
config.region = options.region || "us-east-1";
|
||||
}
|
||||
|
||||
try {
|
||||
const client = new BedrockRuntimeClient(config);
|
||||
|
||||
const cacheRetention = resolveCacheRetention(options.cacheRetention);
|
||||
let commandInput = {
|
||||
modelId: model.id,
|
||||
messages: convertMessages(context, model, cacheRetention),
|
||||
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
|
||||
inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature },
|
||||
toolConfig: convertToolConfig(context.tools, options.toolChoice),
|
||||
additionalModelRequestFields: buildAdditionalModelRequestFields(model, options),
|
||||
...(options.requestMetadata !== undefined && { requestMetadata: options.requestMetadata }),
|
||||
};
|
||||
const nextCommandInput = await options?.onPayload?.(commandInput, model);
|
||||
if (nextCommandInput !== undefined) {
|
||||
commandInput = nextCommandInput as typeof commandInput;
|
||||
}
|
||||
const command = new ConverseStreamCommand(commandInput);
|
||||
|
||||
const response = await client.send(command, { abortSignal: options.signal });
|
||||
|
||||
for await (const item of response.stream!) {
|
||||
if (item.messageStart) {
|
||||
if (item.messageStart.role !== ConversationRole.ASSISTANT) {
|
||||
throw new Error("Unexpected assistant message start but got user message start instead");
|
||||
}
|
||||
stream.push({ type: "start", partial: output });
|
||||
} else if (item.contentBlockStart) {
|
||||
handleContentBlockStart(item.contentBlockStart, blocks, output, stream);
|
||||
} else if (item.contentBlockDelta) {
|
||||
handleContentBlockDelta(item.contentBlockDelta, blocks, output, stream);
|
||||
} else if (item.contentBlockStop) {
|
||||
handleContentBlockStop(item.contentBlockStop, blocks, output, stream);
|
||||
} else if (item.messageStop) {
|
||||
output.stopReason = mapStopReason(item.messageStop.stopReason);
|
||||
} else if (item.metadata) {
|
||||
handleMetadata(item.metadata, model, output);
|
||||
} else if (item.internalServerException) {
|
||||
throw new Error(`Internal server error: ${item.internalServerException.message}`);
|
||||
} else if (item.modelStreamErrorException) {
|
||||
throw new Error(`Model stream error: ${item.modelStreamErrorException.message}`);
|
||||
} else if (item.validationException) {
|
||||
throw new Error(`Validation error: ${item.validationException.message}`);
|
||||
} else if (item.throttlingException) {
|
||||
throw new Error(`Throttling error: ${item.throttlingException.message}`);
|
||||
} else if (item.serviceUnavailableException) {
|
||||
throw new Error(`Service unavailable: ${item.serviceUnavailableException.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "error" || output.stopReason === "aborted") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) {
|
||||
delete (block as Block).index;
|
||||
delete (block as Block).partialJson;
|
||||
}
|
||||
output.stopReason = options.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleBedrock: StreamFunction<"bedrock-converse-stream", SimpleStreamOptions> = (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const base = buildBaseOptions(model, options, undefined);
|
||||
if (!options?.reasoning) {
|
||||
return streamBedrock(model, context, { ...base, reasoning: undefined } satisfies BedrockOptions);
|
||||
}
|
||||
|
||||
if (model.id.includes("anthropic.claude") || model.id.includes("anthropic/claude")) {
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
return streamBedrock(model, context, {
|
||||
...base,
|
||||
reasoning: options.reasoning,
|
||||
thinkingBudgets: options.thinkingBudgets,
|
||||
} satisfies BedrockOptions);
|
||||
}
|
||||
|
||||
const adjusted = adjustMaxTokensForThinking(
|
||||
base.maxTokens || 0,
|
||||
model.maxTokens,
|
||||
options.reasoning,
|
||||
options.thinkingBudgets,
|
||||
);
|
||||
|
||||
return streamBedrock(model, context, {
|
||||
...base,
|
||||
maxTokens: adjusted.maxTokens,
|
||||
reasoning: options.reasoning,
|
||||
thinkingBudgets: {
|
||||
...(options.thinkingBudgets || {}),
|
||||
[clampReasoning(options.reasoning)!]: adjusted.thinkingBudget,
|
||||
},
|
||||
} satisfies BedrockOptions);
|
||||
}
|
||||
|
||||
return streamBedrock(model, context, {
|
||||
...base,
|
||||
reasoning: options.reasoning,
|
||||
thinkingBudgets: options.thinkingBudgets,
|
||||
} satisfies BedrockOptions);
|
||||
};
|
||||
|
||||
function handleContentBlockStart(
|
||||
event: ContentBlockStartEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const index = event.contentBlockIndex!;
|
||||
const start = event.start;
|
||||
|
||||
if (start?.toolUse) {
|
||||
const block: Block = {
|
||||
type: "toolCall",
|
||||
id: start.toolUse.toolUseId || "",
|
||||
name: start.toolUse.name || "",
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blocks.length - 1, partial: output });
|
||||
}
|
||||
}
|
||||
|
||||
function handleContentBlockDelta(
|
||||
event: ContentBlockDeltaEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const contentBlockIndex = event.contentBlockIndex!;
|
||||
const delta = event.delta;
|
||||
let index = blocks.findIndex((b) => b.index === contentBlockIndex);
|
||||
let block = blocks[index];
|
||||
|
||||
if (delta?.text !== undefined) {
|
||||
// If no text block exists yet, create one, as `handleContentBlockStart` is not sent for text blocks
|
||||
if (!block) {
|
||||
const newBlock: Block = { type: "text", text: "", index: contentBlockIndex };
|
||||
output.content.push(newBlock);
|
||||
index = blocks.length - 1;
|
||||
block = blocks[index];
|
||||
stream.push({ type: "text_start", contentIndex: index, partial: output });
|
||||
}
|
||||
if (block.type === "text") {
|
||||
block.text += delta.text;
|
||||
stream.push({ type: "text_delta", contentIndex: index, delta: delta.text, partial: output });
|
||||
}
|
||||
} else if (delta?.toolUse && block?.type === "toolCall") {
|
||||
block.partialJson = (block.partialJson || "") + (delta.toolUse.input || "");
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
stream.push({ type: "toolcall_delta", contentIndex: index, delta: delta.toolUse.input || "", partial: output });
|
||||
} else if (delta?.reasoningContent) {
|
||||
let thinkingBlock = block;
|
||||
let thinkingIndex = index;
|
||||
|
||||
if (!thinkingBlock) {
|
||||
const newBlock: Block = { type: "thinking", thinking: "", thinkingSignature: "", index: contentBlockIndex };
|
||||
output.content.push(newBlock);
|
||||
thinkingIndex = blocks.length - 1;
|
||||
thinkingBlock = blocks[thinkingIndex];
|
||||
stream.push({ type: "thinking_start", contentIndex: thinkingIndex, partial: output });
|
||||
}
|
||||
|
||||
if (thinkingBlock?.type === "thinking") {
|
||||
if (delta.reasoningContent.text) {
|
||||
thinkingBlock.thinking += delta.reasoningContent.text;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: thinkingIndex,
|
||||
delta: delta.reasoningContent.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
if (delta.reasoningContent.signature) {
|
||||
thinkingBlock.thinkingSignature =
|
||||
(thinkingBlock.thinkingSignature || "") + delta.reasoningContent.signature;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleMetadata(
|
||||
event: ConverseStreamMetadataEvent,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
output: AssistantMessage,
|
||||
): void {
|
||||
if (event.usage) {
|
||||
output.usage.input = event.usage.inputTokens || 0;
|
||||
output.usage.output = event.usage.outputTokens || 0;
|
||||
output.usage.cacheRead = event.usage.cacheReadInputTokens || 0;
|
||||
output.usage.cacheWrite = event.usage.cacheWriteInputTokens || 0;
|
||||
output.usage.totalTokens = event.usage.totalTokens || output.usage.input + output.usage.output;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
function handleContentBlockStop(
|
||||
event: ContentBlockStopEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const index = blocks.findIndex((b) => b.index === event.contentBlockIndex);
|
||||
const block = blocks[index];
|
||||
if (!block) return;
|
||||
delete (block as Block).index;
|
||||
|
||||
switch (block.type) {
|
||||
case "text":
|
||||
stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: output });
|
||||
break;
|
||||
case "thinking":
|
||||
stream.push({ type: "thinking_end", contentIndex: index, content: block.thinking, partial: output });
|
||||
break;
|
||||
case "toolCall":
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
delete (block as Block).partialJson;
|
||||
stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: output });
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports adaptive thinking (Opus 4.6 and Sonnet 4.6).
|
||||
*/
|
||||
function supportsAdaptiveThinking(modelId: string): boolean {
|
||||
return (
|
||||
modelId.includes("opus-4-6") ||
|
||||
modelId.includes("opus-4.6") ||
|
||||
modelId.includes("sonnet-4-6") ||
|
||||
modelId.includes("sonnet-4.6")
|
||||
);
|
||||
}
|
||||
|
||||
function mapThinkingLevelToEffort(
|
||||
level: SimpleStreamOptions["reasoning"],
|
||||
modelId: string,
|
||||
): "low" | "medium" | "high" | "max" {
|
||||
switch (level) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "low";
|
||||
case "medium":
|
||||
return "medium";
|
||||
case "high":
|
||||
return "high";
|
||||
case "xhigh":
|
||||
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6") ? "max" : "high";
|
||||
default:
|
||||
return "high";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports prompt caching.
|
||||
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
|
||||
*
|
||||
* For base models and system-defined inference profiles the model ID / ARN
|
||||
* contains the model name, so we can decide locally.
|
||||
*
|
||||
* For application inference profiles (whose ARNs don't contain the model name),
|
||||
* set AWS_BEDROCK_FORCE_CACHE=1 to enable cache points. Amazon Nova models
|
||||
* have automatic caching and don't need explicit cache points.
|
||||
*/
|
||||
function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
|
||||
const id = model.id.toLowerCase();
|
||||
if (!id.includes("claude")) {
|
||||
// Application inference profiles don't contain the model name in the ARN.
|
||||
// Allow users to force cache points via environment variable.
|
||||
if (typeof process !== "undefined" && process.env.AWS_BEDROCK_FORCE_CACHE === "1") return true;
|
||||
return false;
|
||||
}
|
||||
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
|
||||
if (id.includes("-4-") || id.includes("-4.")) return true;
|
||||
// Claude 3.7 Sonnet
|
||||
if (id.includes("claude-3-7-sonnet")) return true;
|
||||
// Claude 3.5 Haiku
|
||||
if (id.includes("claude-3-5-haiku")) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports thinking signatures in reasoningContent.
|
||||
* Only Anthropic Claude models support the signature field.
|
||||
* Other models (OpenAI, Qwen, Minimax, Moonshot, etc.) reject it with:
|
||||
* "This model doesn't support the reasoningContent.reasoningText.signature field"
|
||||
*/
|
||||
function supportsThinkingSignature(model: Model<"bedrock-converse-stream">): boolean {
|
||||
const id = model.id.toLowerCase();
|
||||
return id.includes("anthropic.claude") || id.includes("anthropic/claude");
|
||||
}
|
||||
|
||||
function buildSystemPrompt(
|
||||
systemPrompt: string | undefined,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
): SystemContentBlock[] | undefined {
|
||||
if (!systemPrompt) return undefined;
|
||||
|
||||
const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
|
||||
|
||||
// Add cache point for supported Claude models when caching is enabled
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
|
||||
blocks.push({
|
||||
cachePoint: { type: CachePointType.DEFAULT, ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}) },
|
||||
});
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
function normalizeToolCallId(id: string): string {
|
||||
const sanitized = id.replace(/[^a-zA-Z0-9_-]/g, "_");
|
||||
return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
|
||||
}
|
||||
|
||||
function convertMessages(
|
||||
context: Context,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
): Message[] {
|
||||
const result: Message[] = [];
|
||||
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
||||
|
||||
for (let i = 0; i < transformedMessages.length; i++) {
|
||||
const m = transformedMessages[i];
|
||||
|
||||
switch (m.role) {
|
||||
case "user":
|
||||
result.push({
|
||||
role: ConversationRole.USER,
|
||||
content:
|
||||
typeof m.content === "string"
|
||||
? [{ text: sanitizeSurrogates(m.content) }]
|
||||
: m.content.map((c) => {
|
||||
switch (c.type) {
|
||||
case "text":
|
||||
return { text: sanitizeSurrogates(c.text) };
|
||||
case "image":
|
||||
return { image: createImageBlock(c.mimeType, c.data) };
|
||||
default:
|
||||
throw new Error("Unknown user content type");
|
||||
}
|
||||
}),
|
||||
});
|
||||
break;
|
||||
case "assistant": {
|
||||
// Skip assistant messages with empty content (e.g., from aborted requests)
|
||||
// Bedrock rejects messages with empty content arrays
|
||||
if (m.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const contentBlocks: ContentBlock[] = [];
|
||||
for (const c of m.content) {
|
||||
switch (c.type) {
|
||||
case "text":
|
||||
// Skip empty text blocks
|
||||
if (c.text.trim().length === 0) continue;
|
||||
contentBlocks.push({ text: sanitizeSurrogates(c.text) });
|
||||
break;
|
||||
case "toolCall":
|
||||
contentBlocks.push({
|
||||
toolUse: { toolUseId: c.id, name: c.name, input: c.arguments },
|
||||
});
|
||||
break;
|
||||
case "thinking":
|
||||
// Skip empty thinking blocks
|
||||
if (c.thinking.trim().length === 0) continue;
|
||||
// Only Anthropic models support the signature field in reasoningText.
|
||||
// For other models, we omit the signature to avoid errors like:
|
||||
// "This model doesn't support the reasoningContent.reasoningText.signature field"
|
||||
if (supportsThinkingSignature(model)) {
|
||||
// Signatures arrive after thinking deltas. If a partial or externally
|
||||
// persisted message lacks a signature, Bedrock rejects the replayed
|
||||
// reasoning block. Fall back to plain text, matching Anthropic.
|
||||
if (!c.thinkingSignature || c.thinkingSignature.trim().length === 0) {
|
||||
contentBlocks.push({ text: sanitizeSurrogates(c.thinking) });
|
||||
} else {
|
||||
contentBlocks.push({
|
||||
reasoningContent: {
|
||||
reasoningText: {
|
||||
text: sanitizeSurrogates(c.thinking),
|
||||
signature: c.thinkingSignature,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
contentBlocks.push({
|
||||
reasoningContent: {
|
||||
reasoningText: { text: sanitizeSurrogates(c.thinking) },
|
||||
},
|
||||
});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new Error("Unknown assistant content type");
|
||||
}
|
||||
}
|
||||
// Skip if all content blocks were filtered out
|
||||
if (contentBlocks.length === 0) {
|
||||
continue;
|
||||
}
|
||||
result.push({
|
||||
role: ConversationRole.ASSISTANT,
|
||||
content: contentBlocks,
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "toolResult": {
|
||||
// Collect all consecutive toolResult messages into a single user message
|
||||
// Bedrock requires all tool results to be in one message
|
||||
const toolResults: ContentBlock.ToolResultMember[] = [];
|
||||
|
||||
// Add current tool result with all content blocks combined
|
||||
toolResults.push({
|
||||
toolResult: {
|
||||
toolUseId: m.toolCallId,
|
||||
content: m.content.map((c) =>
|
||||
c.type === "image"
|
||||
? { image: createImageBlock(c.mimeType, c.data) }
|
||||
: { text: sanitizeSurrogates(c.text) },
|
||||
),
|
||||
status: m.isError ? ToolResultStatus.ERROR : ToolResultStatus.SUCCESS,
|
||||
},
|
||||
});
|
||||
|
||||
// Look ahead for consecutive toolResult messages
|
||||
let j = i + 1;
|
||||
while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") {
|
||||
const nextMsg = transformedMessages[j] as ToolResultMessage;
|
||||
toolResults.push({
|
||||
toolResult: {
|
||||
toolUseId: nextMsg.toolCallId,
|
||||
content: nextMsg.content.map((c) =>
|
||||
c.type === "image"
|
||||
? { image: createImageBlock(c.mimeType, c.data) }
|
||||
: { text: sanitizeSurrogates(c.text) },
|
||||
),
|
||||
status: nextMsg.isError ? ToolResultStatus.ERROR : ToolResultStatus.SUCCESS,
|
||||
},
|
||||
});
|
||||
j++;
|
||||
}
|
||||
|
||||
// Skip the messages we've already processed
|
||||
i = j - 1;
|
||||
|
||||
result.push({
|
||||
role: ConversationRole.USER,
|
||||
content: toolResults,
|
||||
});
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw new Error("Unknown message role");
|
||||
}
|
||||
}
|
||||
|
||||
// Add cache point to the last user message for supported Claude models when caching is enabled
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model) && result.length > 0) {
|
||||
const lastMessage = result[result.length - 1];
|
||||
if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
|
||||
(lastMessage.content as ContentBlock[]).push({
|
||||
cachePoint: {
|
||||
type: CachePointType.DEFAULT,
|
||||
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function convertToolConfig(
|
||||
tools: Tool[] | undefined,
|
||||
toolChoice: BedrockOptions["toolChoice"],
|
||||
): ToolConfiguration | undefined {
|
||||
if (!tools?.length || toolChoice === "none") return undefined;
|
||||
|
||||
const bedrockTools: BedrockTool[] = tools.map((tool) => ({
|
||||
toolSpec: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
inputSchema: { json: tool.parameters },
|
||||
},
|
||||
}));
|
||||
|
||||
let bedrockToolChoice: ToolChoice | undefined;
|
||||
switch (toolChoice) {
|
||||
case "auto":
|
||||
bedrockToolChoice = { auto: {} };
|
||||
break;
|
||||
case "any":
|
||||
bedrockToolChoice = { any: {} };
|
||||
break;
|
||||
default:
|
||||
if (toolChoice?.type === "tool") {
|
||||
bedrockToolChoice = { tool: { name: toolChoice.name } };
|
||||
}
|
||||
}
|
||||
|
||||
return { tools: bedrockTools, toolChoice: bedrockToolChoice };
|
||||
}
|
||||
|
||||
function mapStopReason(reason: string | undefined): StopReason {
|
||||
switch (reason) {
|
||||
case BedrockStopReason.END_TURN:
|
||||
case BedrockStopReason.STOP_SEQUENCE:
|
||||
return "stop";
|
||||
case BedrockStopReason.MAX_TOKENS:
|
||||
case BedrockStopReason.MODEL_CONTEXT_WINDOW_EXCEEDED:
|
||||
return "length";
|
||||
case BedrockStopReason.TOOL_USE:
|
||||
return "toolUse";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
function buildAdditionalModelRequestFields(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
options: BedrockOptions,
|
||||
): Record<string, any> | undefined {
|
||||
if (!options.reasoning || !model.reasoning) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (model.id.includes("anthropic.claude") || model.id.includes("anthropic/claude")) {
|
||||
const result: Record<string, any> = supportsAdaptiveThinking(model.id)
|
||||
? {
|
||||
thinking: { type: "adaptive" },
|
||||
output_config: { effort: mapThinkingLevelToEffort(options.reasoning, model.id) },
|
||||
}
|
||||
: (() => {
|
||||
const defaultBudgets: Record<ThinkingLevel, number> = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
xhigh: 16384, // Claude doesn't support xhigh, clamp to high
|
||||
};
|
||||
|
||||
// Custom budgets override defaults (xhigh not in ThinkingBudgets, use high)
|
||||
const level = options.reasoning === "xhigh" ? "high" : options.reasoning;
|
||||
const budget = options.thinkingBudgets?.[level] ?? defaultBudgets[options.reasoning];
|
||||
|
||||
return {
|
||||
thinking: {
|
||||
type: "enabled",
|
||||
budget_tokens: budget,
|
||||
},
|
||||
};
|
||||
})();
|
||||
|
||||
if (!supportsAdaptiveThinking(model.id) && (options.interleavedThinking ?? true)) {
|
||||
result.anthropic_beta = ["interleaved-thinking-2025-05-14"];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function createImageBlock(mimeType: string, data: string) {
|
||||
let format: ImageFormat;
|
||||
switch (mimeType) {
|
||||
case "image/jpeg":
|
||||
case "image/jpg":
|
||||
format = ImageFormat.JPEG;
|
||||
break;
|
||||
case "image/png":
|
||||
format = ImageFormat.PNG;
|
||||
break;
|
||||
case "image/gif":
|
||||
format = ImageFormat.GIF;
|
||||
break;
|
||||
case "image/webp":
|
||||
format = ImageFormat.WEBP;
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown image type: ${mimeType}`);
|
||||
}
|
||||
|
||||
const binaryString = atob(data);
|
||||
const bytes = new Uint8Array(binaryString.length);
|
||||
for (let i = 0; i < binaryString.length; i++) {
|
||||
bytes[i] = binaryString.charCodeAt(i);
|
||||
}
|
||||
|
||||
return { source: { bytes }, format };
|
||||
}
|
||||
905
packages/ai/src/providers/anthropic.ts
Normal file
905
packages/ai/src/providers/anthropic.ts
Normal file
|
|
@ -0,0 +1,905 @@
|
|||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import type {
|
||||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
|
||||
import { adjustMaxTokensForThinking, buildBaseOptions } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
function getCacheControl(
|
||||
baseUrl: string,
|
||||
cacheRetention?: CacheRetention,
|
||||
): { retention: CacheRetention; cacheControl?: { type: "ephemeral"; ttl?: "1h" } } {
|
||||
const retention = resolveCacheRetention(cacheRetention);
|
||||
if (retention === "none") {
|
||||
return { retention };
|
||||
}
|
||||
const ttl = retention === "long" && baseUrl.includes("api.anthropic.com") ? "1h" : undefined;
|
||||
return {
|
||||
retention,
|
||||
cacheControl: { type: "ephemeral", ...(ttl && { ttl }) },
|
||||
};
|
||||
}
|
||||
|
||||
// Stealth mode: Mimic Claude Code's tool naming exactly
|
||||
const claudeCodeVersion = "2.1.75";
|
||||
|
||||
// Claude Code 2.x tool names (canonical casing)
|
||||
// Source: https://cchistory.mariozechner.at/data/prompts-2.1.11.md
|
||||
// To update: https://github.com/badlogic/cchistory
|
||||
const claudeCodeTools = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Bash",
|
||||
"Grep",
|
||||
"Glob",
|
||||
"AskUserQuestion",
|
||||
"EnterPlanMode",
|
||||
"ExitPlanMode",
|
||||
"KillShell",
|
||||
"NotebookEdit",
|
||||
"Skill",
|
||||
"Task",
|
||||
"TaskOutput",
|
||||
"TodoWrite",
|
||||
"WebFetch",
|
||||
"WebSearch",
|
||||
];
|
||||
|
||||
const ccToolLookup = new Map(claudeCodeTools.map((t) => [t.toLowerCase(), t]));
|
||||
|
||||
// Convert tool name to CC canonical casing if it matches (case-insensitive)
|
||||
const toClaudeCodeName = (name: string) => ccToolLookup.get(name.toLowerCase()) ?? name;
|
||||
const fromClaudeCodeName = (name: string, tools?: Tool[]) => {
|
||||
if (tools && tools.length > 0) {
|
||||
const lowerName = name.toLowerCase();
|
||||
const matchedTool = tools.find((tool) => tool.name.toLowerCase() === lowerName);
|
||||
if (matchedTool) return matchedTool.name;
|
||||
}
|
||||
return name;
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert content blocks to Anthropic API format
|
||||
*/
|
||||
function convertContentBlocks(content: (TextContent | ImageContent)[]):
|
||||
| string
|
||||
| Array<
|
||||
| { type: "text"; text: string }
|
||||
| {
|
||||
type: "image";
|
||||
source: {
|
||||
type: "base64";
|
||||
media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp";
|
||||
data: string;
|
||||
};
|
||||
}
|
||||
> {
|
||||
// If only text blocks, return as concatenated string for simplicity
|
||||
const hasImages = content.some((c) => c.type === "image");
|
||||
if (!hasImages) {
|
||||
return sanitizeSurrogates(content.map((c) => (c as TextContent).text).join("\n"));
|
||||
}
|
||||
|
||||
// If we have images, convert to content block array
|
||||
const blocks = content.map((block) => {
|
||||
if (block.type === "text") {
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: sanitizeSurrogates(block.text),
|
||||
};
|
||||
}
|
||||
return {
|
||||
type: "image" as const,
|
||||
source: {
|
||||
type: "base64" as const,
|
||||
media_type: block.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: block.data,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// If only images (no text), add placeholder text block
|
||||
const hasText = blocks.some((b) => b.type === "text");
|
||||
if (!hasText) {
|
||||
blocks.unshift({
|
||||
type: "text" as const,
|
||||
text: "(see attached image)",
|
||||
});
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
export type AnthropicEffort = "low" | "medium" | "high" | "max";
|
||||
|
||||
export interface AnthropicOptions extends StreamOptions {
|
||||
/**
|
||||
* Enable extended thinking.
|
||||
* For Opus 4.6 and Sonnet 4.6: uses adaptive thinking (model decides when/how much to think).
|
||||
* For older models: uses budget-based thinking with thinkingBudgetTokens.
|
||||
*/
|
||||
thinkingEnabled?: boolean;
|
||||
/**
|
||||
* Token budget for extended thinking (older models only).
|
||||
* Ignored for Opus 4.6 and Sonnet 4.6, which use adaptive thinking.
|
||||
*/
|
||||
thinkingBudgetTokens?: number;
|
||||
/**
|
||||
* Effort level for adaptive thinking (Opus 4.6 and Sonnet 4.6).
|
||||
* Controls how much thinking Claude allocates:
|
||||
* - "max": Always thinks with no constraints (Opus 4.6 only)
|
||||
* - "high": Always thinks, deep reasoning (default)
|
||||
* - "medium": Moderate thinking, may skip for simple queries
|
||||
* - "low": Minimal thinking, skips for simple tasks
|
||||
* Ignored for older models.
|
||||
*/
|
||||
effort?: AnthropicEffort;
|
||||
interleavedThinking?: boolean;
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
/**
|
||||
* Pre-built Anthropic client instance. When provided, skips internal client
|
||||
* construction entirely. Use this to inject alternative SDK clients such as
|
||||
* `AnthropicVertex` that shares the same messaging API.
|
||||
*/
|
||||
client?: Anthropic;
|
||||
}
|
||||
|
||||
function mergeHeaders(...headerSources: (Record<string, string> | undefined)[]): Record<string, string> {
|
||||
const merged: Record<string, string> = {};
|
||||
for (const headers of headerSources) {
|
||||
if (headers) {
|
||||
Object.assign(merged, headers);
|
||||
}
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
|
||||
export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOptions> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
options?: AnthropicOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
let client: Anthropic;
|
||||
let isOAuth: boolean;
|
||||
|
||||
if (options?.client) {
|
||||
client = options.client;
|
||||
isOAuth = false;
|
||||
} else {
|
||||
const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? "";
|
||||
|
||||
let copilotDynamicHeaders: Record<string, string> | undefined;
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
copilotDynamicHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
}
|
||||
|
||||
const created = createClient(
|
||||
model,
|
||||
apiKey,
|
||||
options?.interleavedThinking ?? true,
|
||||
options?.headers,
|
||||
copilotDynamicHeaders,
|
||||
);
|
||||
client = created.client;
|
||||
isOAuth = created.isOAuthToken;
|
||||
}
|
||||
let params = buildParams(model, context, isOAuth, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as MessageCreateParamsStreaming;
|
||||
}
|
||||
const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number };
|
||||
const blocks = output.content as Block[];
|
||||
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === "message_start") {
|
||||
output.responseId = event.message.id;
|
||||
// Capture initial token usage from message_start event
|
||||
// This ensures we have input token counts even if the stream is aborted early
|
||||
output.usage.input = event.message.usage.input_tokens || 0;
|
||||
output.usage.output = event.message.usage.output_tokens || 0;
|
||||
output.usage.cacheRead = event.message.usage.cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite = event.message.usage.cache_creation_input_tokens || 0;
|
||||
// Anthropic doesn't provide total_tokens, compute from components
|
||||
output.usage.totalTokens =
|
||||
output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
} else if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
const block: Block = {
|
||||
type: "text",
|
||||
text: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
const block: Block = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "redacted_thinking") {
|
||||
const block: Block = {
|
||||
type: "thinking",
|
||||
thinking: "[Reasoning redacted]",
|
||||
thinkingSignature: event.content_block.data,
|
||||
redacted: true,
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
const block: Block = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
name: isOAuth
|
||||
? fromClaudeCodeName(event.content_block.name, context.tools)
|
||||
: event.content_block.name,
|
||||
arguments: (event.content_block.input as Record<string, any>) ?? {},
|
||||
partialJson: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
|
||||
}
|
||||
} else if (event.type === "content_block_delta") {
|
||||
if (event.delta.type === "text_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "text") {
|
||||
block.text += event.delta.text;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "thinking_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "thinking") {
|
||||
block.thinking += event.delta.thinking;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "input_json_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "toolCall") {
|
||||
block.partialJson += event.delta.partial_json;
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.partial_json,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "signature_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "thinking") {
|
||||
block.thinkingSignature = block.thinkingSignature || "";
|
||||
block.thinkingSignature += event.delta.signature;
|
||||
}
|
||||
}
|
||||
} else if (event.type === "content_block_stop") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block) {
|
||||
delete (block as any).index;
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: index,
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: index,
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
delete (block as any).partialJson;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: index,
|
||||
toolCall: block,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if (event.delta.stop_reason) {
|
||||
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||
}
|
||||
// Only update usage fields if present (not null).
|
||||
// Preserves input_tokens from message_start when proxies omit it in message_delta.
|
||||
if (event.usage.input_tokens != null) {
|
||||
output.usage.input = event.usage.input_tokens;
|
||||
}
|
||||
if (event.usage.output_tokens != null) {
|
||||
output.usage.output = event.usage.output_tokens;
|
||||
}
|
||||
if (event.usage.cache_read_input_tokens != null) {
|
||||
output.usage.cacheRead = event.usage.cache_read_input_tokens;
|
||||
}
|
||||
if (event.usage.cache_creation_input_tokens != null) {
|
||||
output.usage.cacheWrite = event.usage.cache_creation_input_tokens;
|
||||
}
|
||||
// Anthropic doesn't provide total_tokens, compute from components
|
||||
output.usage.totalTokens =
|
||||
output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as any).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if a model supports adaptive thinking (Opus 4.6 and Sonnet 4.6)
|
||||
*/
|
||||
function supportsAdaptiveThinking(modelId: string): boolean {
|
||||
// Opus 4.6 and Sonnet 4.6 model IDs (with or without date suffix)
|
||||
return (
|
||||
modelId.includes("opus-4-6") ||
|
||||
modelId.includes("opus-4.6") ||
|
||||
modelId.includes("sonnet-4-6") ||
|
||||
modelId.includes("sonnet-4.6")
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Map ThinkingLevel to Anthropic effort levels for adaptive thinking.
|
||||
* Note: effort "max" is only valid on Opus 4.6.
|
||||
*/
|
||||
function mapThinkingLevelToEffort(level: SimpleStreamOptions["reasoning"], modelId: string): AnthropicEffort {
|
||||
switch (level) {
|
||||
case "minimal":
|
||||
return "low";
|
||||
case "low":
|
||||
return "low";
|
||||
case "medium":
|
||||
return "medium";
|
||||
case "high":
|
||||
return "high";
|
||||
case "xhigh":
|
||||
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6") ? "max" : "high";
|
||||
default:
|
||||
return "high";
|
||||
}
|
||||
}
|
||||
|
||||
export const streamSimpleAnthropic: StreamFunction<"anthropic-messages", SimpleStreamOptions> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamAnthropic(model, context, { ...base, thinkingEnabled: false } satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
// For Opus 4.6 and Sonnet 4.6: use adaptive thinking with effort level
|
||||
// For older models: use budget-based thinking
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
const effort = mapThinkingLevelToEffort(options.reasoning, model.id);
|
||||
return streamAnthropic(model, context, {
|
||||
...base,
|
||||
thinkingEnabled: true,
|
||||
effort,
|
||||
} satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
const adjusted = adjustMaxTokensForThinking(
|
||||
base.maxTokens || 0,
|
||||
model.maxTokens,
|
||||
options.reasoning,
|
||||
options.thinkingBudgets,
|
||||
);
|
||||
|
||||
return streamAnthropic(model, context, {
|
||||
...base,
|
||||
maxTokens: adjusted.maxTokens,
|
||||
thinkingEnabled: true,
|
||||
thinkingBudgetTokens: adjusted.thinkingBudget,
|
||||
} satisfies AnthropicOptions);
|
||||
};
|
||||
|
||||
function isOAuthToken(apiKey: string): boolean {
|
||||
return apiKey.includes("sk-ant-oat");
|
||||
}
|
||||
|
||||
function createClient(
|
||||
model: Model<"anthropic-messages">,
|
||||
apiKey: string,
|
||||
interleavedThinking: boolean,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
dynamicHeaders?: Record<string, string>,
|
||||
): { client: Anthropic; isOAuthToken: boolean } {
|
||||
// Adaptive thinking models (Opus 4.6, Sonnet 4.6) have interleaved thinking built-in.
|
||||
// The beta header is deprecated on Opus 4.6 and redundant on Sonnet 4.6, so skip it.
|
||||
const needsInterleavedBeta = interleavedThinking && !supportsAdaptiveThinking(model.id);
|
||||
|
||||
// Copilot: Bearer auth, selective betas (no fine-grained-tool-streaming)
|
||||
if (model.provider === "github-copilot") {
|
||||
const betaFeatures: string[] = [];
|
||||
if (needsInterleavedBeta) {
|
||||
betaFeatures.push("interleaved-thinking-2025-05-14");
|
||||
}
|
||||
|
||||
const client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
...(betaFeatures.length > 0 ? { "anthropic-beta": betaFeatures.join(",") } : {}),
|
||||
},
|
||||
model.headers,
|
||||
dynamicHeaders,
|
||||
optionsHeaders,
|
||||
),
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
|
||||
const betaFeatures = ["fine-grained-tool-streaming-2025-05-14"];
|
||||
if (needsInterleavedBeta) {
|
||||
betaFeatures.push("interleaved-thinking-2025-05-14");
|
||||
}
|
||||
|
||||
// OAuth: Bearer auth, Claude Code identity headers
|
||||
if (isOAuthToken(apiKey)) {
|
||||
const client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`,
|
||||
"user-agent": `claude-cli/${claudeCodeVersion}`,
|
||||
"x-app": "cli",
|
||||
},
|
||||
model.headers,
|
||||
optionsHeaders,
|
||||
),
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: true };
|
||||
}
|
||||
|
||||
// API key auth
|
||||
const client = new Anthropic({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": betaFeatures.join(","),
|
||||
},
|
||||
model.headers,
|
||||
optionsHeaders,
|
||||
),
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
isOAuthToken: boolean,
|
||||
options?: AnthropicOptions,
|
||||
): MessageCreateParamsStreaming {
|
||||
const { cacheControl } = getCacheControl(model.baseUrl, options?.cacheRetention);
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages: convertMessages(context.messages, model, isOAuthToken, cacheControl),
|
||||
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
if (isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
// Add cache control to system prompt for non-OAuth tokens
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
// Temperature is incompatible with extended thinking (adaptive or budget-based).
|
||||
if (options?.temperature !== undefined && !options?.thinkingEnabled) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools, isOAuthToken);
|
||||
}
|
||||
|
||||
// Configure thinking mode: adaptive (Opus 4.6 and Sonnet 4.6),
|
||||
// budget-based (older models), or explicitly disabled.
|
||||
if (model.reasoning) {
|
||||
if (options?.thinkingEnabled) {
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
// Adaptive thinking: Claude decides when and how much to think
|
||||
params.thinking = { type: "adaptive" };
|
||||
if (options.effort) {
|
||||
params.output_config = { effort: options.effort };
|
||||
}
|
||||
} else {
|
||||
// Budget-based thinking for older models
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinkingBudgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
} else if (options?.thinkingEnabled === false) {
|
||||
params.thinking = { type: "disabled" };
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.metadata) {
|
||||
const userId = options.metadata.user_id;
|
||||
if (typeof userId === "string") {
|
||||
params.metadata = { user_id: userId };
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
// Normalize tool call IDs to match Anthropic's required pattern and length
|
||||
function normalizeToolCallId(id: string): string {
|
||||
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
|
||||
}
|
||||
|
||||
function convertMessages(
|
||||
messages: Message[],
|
||||
model: Model<"anthropic-messages">,
|
||||
isOAuthToken: boolean,
|
||||
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
|
||||
): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, model, normalizeToolCallId);
|
||||
|
||||
for (let i = 0; i < transformedMessages.length; i++) {
|
||||
const msg = transformedMessages[i];
|
||||
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
if (msg.content.trim().length > 0) {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: sanitizeSurrogates(msg.content),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(item.text),
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
let filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
|
||||
filteredBlocks = filteredBlocks.filter((b) => {
|
||||
if (b.type === "text") {
|
||||
return b.text.trim().length > 0;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (filteredBlocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredBlocks,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
if (block.text.trim().length === 0) continue;
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(block.text),
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
// Redacted thinking: pass the opaque payload back as redacted_thinking
|
||||
if (block.redacted) {
|
||||
blocks.push({
|
||||
type: "redacted_thinking",
|
||||
data: block.thinkingSignature!,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (block.thinking.trim().length === 0) continue;
|
||||
// If thinking signature is missing/empty (e.g., from aborted stream),
|
||||
// convert to plain text block without <thinking> tags to avoid API rejection
|
||||
// and prevent Claude from mimicking the tags in responses
|
||||
if (!block.thinkingSignature || block.thinkingSignature.trim().length === 0) {
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(block.thinking),
|
||||
});
|
||||
} else {
|
||||
blocks.push({
|
||||
type: "thinking",
|
||||
thinking: sanitizeSurrogates(block.thinking),
|
||||
signature: block.thinkingSignature,
|
||||
});
|
||||
}
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: isOAuthToken ? toClaudeCodeName(block.name) : block.name,
|
||||
input: block.arguments ?? {},
|
||||
});
|
||||
}
|
||||
}
|
||||
if (blocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
// Collect all consecutive toolResult messages, needed for z.ai Anthropic endpoint
|
||||
const toolResults: ContentBlockParam[] = [];
|
||||
|
||||
// Add the current tool result
|
||||
toolResults.push({
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: convertContentBlocks(msg.content),
|
||||
is_error: msg.isError,
|
||||
});
|
||||
|
||||
// Look ahead for consecutive toolResult messages
|
||||
let j = i + 1;
|
||||
while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") {
|
||||
const nextMsg = transformedMessages[j] as ToolResultMessage; // We know it's a toolResult
|
||||
toolResults.push({
|
||||
type: "tool_result",
|
||||
tool_use_id: nextMsg.toolCallId,
|
||||
content: convertContentBlocks(nextMsg.content),
|
||||
is_error: nextMsg.isError,
|
||||
});
|
||||
j++;
|
||||
}
|
||||
|
||||
// Skip the messages we've already processed
|
||||
i = j - 1;
|
||||
|
||||
// Add a single user message with all tool results
|
||||
params.push({
|
||||
role: "user",
|
||||
content: toolResults,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add cache_control to the last user message to cache conversation history
|
||||
if (cacheControl && params.length > 0) {
|
||||
const lastMessage = params[params.length - 1];
|
||||
if (lastMessage.role === "user") {
|
||||
if (Array.isArray(lastMessage.content)) {
|
||||
const lastBlock = lastMessage.content[lastMessage.content.length - 1];
|
||||
if (
|
||||
lastBlock &&
|
||||
(lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result")
|
||||
) {
|
||||
(lastBlock as any).cache_control = cacheControl;
|
||||
}
|
||||
} else if (typeof lastMessage.content === "string") {
|
||||
lastMessage.content = [
|
||||
{
|
||||
type: "text",
|
||||
text: lastMessage.content,
|
||||
cache_control: cacheControl,
|
||||
},
|
||||
] as any;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function convertTools(tools: Tool[], isOAuthToken: boolean): Anthropic.Messages.Tool[] {
|
||||
if (!tools) return [];
|
||||
|
||||
return tools.map((tool) => {
|
||||
const jsonSchema = tool.parameters as any; // TypeBox already generates JSON Schema
|
||||
|
||||
return {
|
||||
name: isOAuthToken ? toClaudeCodeName(tool.name) : tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: jsonSchema.properties || {},
|
||||
required: jsonSchema.required || [],
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function mapStopReason(reason: Anthropic.Messages.StopReason | string): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "error";
|
||||
case "pause_turn": // Stop is good enough -> resubmit
|
||||
return "stop";
|
||||
case "stop_sequence":
|
||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||
case "sensitive": // Content flagged by safety filters (not yet in SDK types)
|
||||
return "error";
|
||||
default:
|
||||
// Handle unknown stop reasons gracefully (API may add new values)
|
||||
throw new Error(`Unhandled stop reason: ${reason}`);
|
||||
}
|
||||
}
|
||||
248
packages/ai/src/providers/azure-openai-responses.ts
Normal file
248
packages/ai/src/providers/azure-openai-responses.ts
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
import { AzureOpenAI } from "openai";
|
||||
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
const DEFAULT_AZURE_API_VERSION = "v1";
|
||||
const AZURE_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode", "azure-openai-responses"]);
|
||||
|
||||
function parseDeploymentNameMap(value: string | undefined): Map<string, string> {
|
||||
const map = new Map<string, string>();
|
||||
if (!value) return map;
|
||||
for (const entry of value.split(",")) {
|
||||
const trimmed = entry.trim();
|
||||
if (!trimmed) continue;
|
||||
const [modelId, deploymentName] = trimmed.split("=", 2);
|
||||
if (!modelId || !deploymentName) continue;
|
||||
map.set(modelId.trim(), deploymentName.trim());
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
function resolveDeploymentName(model: Model<"azure-openai-responses">, options?: AzureOpenAIResponsesOptions): string {
|
||||
if (options?.azureDeploymentName) {
|
||||
return options.azureDeploymentName;
|
||||
}
|
||||
const mappedDeployment = parseDeploymentNameMap(process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP).get(model.id);
|
||||
return mappedDeployment || model.id;
|
||||
}
|
||||
|
||||
// Azure OpenAI Responses-specific options
|
||||
export interface AzureOpenAIResponsesOptions extends StreamOptions {
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
reasoningSummary?: "auto" | "detailed" | "concise" | null;
|
||||
azureApiVersion?: string;
|
||||
azureResourceName?: string;
|
||||
azureBaseUrl?: string;
|
||||
azureDeploymentName?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate function for Azure OpenAI Responses API
|
||||
*/
|
||||
export const streamAzureOpenAIResponses: StreamFunction<"azure-openai-responses", AzureOpenAIResponsesOptions> = (
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const deploymentName = resolveDeploymentName(model, options);
|
||||
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "azure-openai-responses" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
// Create Azure OpenAI client
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, apiKey, options);
|
||||
let params = buildParams(model, context, options, deploymentName);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as ResponseCreateParamsStreaming;
|
||||
}
|
||||
const openaiStream = await client.responses.create(
|
||||
params,
|
||||
options?.signal ? { signal: options.signal } : undefined,
|
||||
);
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
await processResponsesStream(openaiStream, output, stream, model);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as { index?: number }).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleAzureOpenAIResponses: StreamFunction<"azure-openai-responses", SimpleStreamOptions> = (
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
|
||||
|
||||
return streamAzureOpenAIResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies AzureOpenAIResponsesOptions);
|
||||
};
|
||||
|
||||
function normalizeAzureBaseUrl(baseUrl: string): string {
|
||||
return baseUrl.replace(/\/+$/, "");
|
||||
}
|
||||
|
||||
function buildDefaultBaseUrl(resourceName: string): string {
|
||||
return `https://${resourceName}.openai.azure.com/openai/v1`;
|
||||
}
|
||||
|
||||
function resolveAzureConfig(
|
||||
model: Model<"azure-openai-responses">,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): { baseUrl: string; apiVersion: string } {
|
||||
const apiVersion = options?.azureApiVersion || process.env.AZURE_OPENAI_API_VERSION || DEFAULT_AZURE_API_VERSION;
|
||||
|
||||
const baseUrl = options?.azureBaseUrl?.trim() || process.env.AZURE_OPENAI_BASE_URL?.trim() || undefined;
|
||||
const resourceName = options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;
|
||||
|
||||
let resolvedBaseUrl = baseUrl;
|
||||
|
||||
if (!resolvedBaseUrl && resourceName) {
|
||||
resolvedBaseUrl = buildDefaultBaseUrl(resourceName);
|
||||
}
|
||||
|
||||
if (!resolvedBaseUrl && model.baseUrl) {
|
||||
resolvedBaseUrl = model.baseUrl;
|
||||
}
|
||||
|
||||
if (!resolvedBaseUrl) {
|
||||
throw new Error(
|
||||
"Azure OpenAI base URL is required. Set AZURE_OPENAI_BASE_URL or AZURE_OPENAI_RESOURCE_NAME, or pass azureBaseUrl, azureResourceName, or model.baseUrl.",
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
baseUrl: normalizeAzureBaseUrl(resolvedBaseUrl),
|
||||
apiVersion,
|
||||
};
|
||||
}
|
||||
|
||||
function createClient(model: Model<"azure-openai-responses">, apiKey: string, options?: AzureOpenAIResponsesOptions) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.AZURE_OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.AZURE_OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
|
||||
if (options?.headers) {
|
||||
Object.assign(headers, options.headers);
|
||||
}
|
||||
|
||||
const { baseUrl, apiVersion } = resolveAzureConfig(model, options);
|
||||
|
||||
return new AzureOpenAI({
|
||||
apiKey,
|
||||
apiVersion,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
baseURL: baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options: AzureOpenAIResponsesOptions | undefined,
|
||||
deploymentName: string,
|
||||
) {
|
||||
const messages = convertResponsesMessages(model, context, AZURE_TOOL_CALL_PROVIDERS);
|
||||
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: deploymentName,
|
||||
input: messages,
|
||||
stream: true,
|
||||
prompt_cache_key: options?.sessionId,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertResponsesTools(context.tools);
|
||||
}
|
||||
|
||||
if (model.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
params.reasoning = { effort: "none" };
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
37
packages/ai/src/providers/github-copilot-headers.ts
Normal file
37
packages/ai/src/providers/github-copilot-headers.ts
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import type { Message } from "../types.js";
|
||||
|
||||
// Copilot expects X-Initiator to indicate whether the request is user-initiated
|
||||
// or agent-initiated (e.g. follow-up after assistant/tool messages).
|
||||
export function inferCopilotInitiator(messages: Message[]): "user" | "agent" {
|
||||
const last = messages[messages.length - 1];
|
||||
return last && last.role !== "user" ? "agent" : "user";
|
||||
}
|
||||
|
||||
// Copilot requires Copilot-Vision-Request header when sending images
|
||||
export function hasCopilotVisionInput(messages: Message[]): boolean {
|
||||
return messages.some((msg) => {
|
||||
if (msg.role === "user" && Array.isArray(msg.content)) {
|
||||
return msg.content.some((c) => c.type === "image");
|
||||
}
|
||||
if (msg.role === "toolResult" && Array.isArray(msg.content)) {
|
||||
return msg.content.some((c) => c.type === "image");
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
export function buildCopilotDynamicHeaders(params: {
|
||||
messages: Message[];
|
||||
hasImages: boolean;
|
||||
}): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
"X-Initiator": inferCopilotInitiator(params.messages),
|
||||
"Openai-Intent": "conversation-edits",
|
||||
};
|
||||
|
||||
if (params.hasImages) {
|
||||
headers["Copilot-Vision-Request"] = "true";
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
987
packages/ai/src/providers/google-gemini-cli.ts
Normal file
987
packages/ai/src/providers/google-gemini-cli.ts
Normal file
|
|
@ -0,0 +1,987 @@
|
|||
/**
|
||||
* Google Gemini CLI / Antigravity provider.
|
||||
* Shared implementation for both google-gemini-cli and google-antigravity providers.
|
||||
* Uses the Cloud Code Assist API endpoint to access Gemini and Claude models.
|
||||
*/
|
||||
|
||||
import type { Content, ThinkingConfig } from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReasonString,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
/**
|
||||
* Thinking level for Gemini 3 models.
|
||||
* Mirrors Google's ThinkingLevel enum values.
|
||||
*/
|
||||
export type GoogleThinkingLevel = "THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH";
|
||||
|
||||
export interface GoogleGeminiCliOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
/**
|
||||
* Thinking/reasoning configuration.
|
||||
* - Gemini 2.x models: use `budgetTokens` to set the thinking budget
|
||||
* - Gemini 3 models (gemini-3-pro-*, gemini-3-flash-*): use `level` instead
|
||||
*
|
||||
* When using `streamSimple`, this is handled automatically based on the model.
|
||||
*/
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
/** Thinking budget in tokens. Use for Gemini 2.x models. */
|
||||
budgetTokens?: number;
|
||||
/** Thinking level. Use for Gemini 3 models (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash). */
|
||||
level?: GoogleThinkingLevel;
|
||||
};
|
||||
projectId?: string;
|
||||
}
|
||||
|
||||
const DEFAULT_ENDPOINT = "https://cloudcode-pa.googleapis.com";
|
||||
const ANTIGRAVITY_DAILY_ENDPOINT = "https://daily-cloudcode-pa.sandbox.googleapis.com";
|
||||
const ANTIGRAVITY_AUTOPUSH_ENDPOINT = "https://autopush-cloudcode-pa.sandbox.googleapis.com";
|
||||
const ANTIGRAVITY_ENDPOINT_FALLBACKS = [
|
||||
ANTIGRAVITY_DAILY_ENDPOINT,
|
||||
ANTIGRAVITY_AUTOPUSH_ENDPOINT,
|
||||
DEFAULT_ENDPOINT,
|
||||
] as const;
|
||||
// Headers for Gemini CLI (prod endpoint)
|
||||
const GEMINI_CLI_HEADERS = {
|
||||
"User-Agent": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
"X-Goog-Api-Client": "gl-node/22.17.0",
|
||||
"Client-Metadata": JSON.stringify({
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
}),
|
||||
};
|
||||
|
||||
// Headers for Antigravity (sandbox endpoint) - requires specific User-Agent
|
||||
const DEFAULT_ANTIGRAVITY_VERSION = "1.18.4";
|
||||
|
||||
function getAntigravityHeaders() {
|
||||
const version = process.env.PI_AI_ANTIGRAVITY_VERSION || DEFAULT_ANTIGRAVITY_VERSION;
|
||||
return {
|
||||
"User-Agent": `antigravity/${version} darwin/arm64`,
|
||||
};
|
||||
}
|
||||
|
||||
// Antigravity system instruction (compact version from CLIProxyAPI).
|
||||
const ANTIGRAVITY_SYSTEM_INSTRUCTION =
|
||||
"You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding." +
|
||||
"You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question." +
|
||||
"**Absolute paths only**" +
|
||||
"**Proactiveness**";
|
||||
|
||||
// Counter for generating unique tool call IDs
|
||||
let toolCallCounter = 0;
|
||||
|
||||
// Retry configuration
|
||||
const MAX_RETRIES = 3;
|
||||
const BASE_DELAY_MS = 1000;
|
||||
const MAX_EMPTY_STREAM_RETRIES = 2;
|
||||
const EMPTY_STREAM_BASE_DELAY_MS = 500;
|
||||
const CLAUDE_THINKING_BETA_HEADER = "interleaved-thinking-2025-05-14";
|
||||
|
||||
/**
|
||||
* Extract retry delay from Gemini error response (in milliseconds).
|
||||
* Checks headers first (Retry-After, x-ratelimit-reset, x-ratelimit-reset-after),
|
||||
* then parses body patterns like:
|
||||
* - "Your quota will reset after 39s"
|
||||
* - "Your quota will reset after 18h31m10s"
|
||||
* - "Please retry in Xs" or "Please retry in Xms"
|
||||
* - "retryDelay": "34.074824224s" (JSON field)
|
||||
*/
|
||||
export function extractRetryDelay(errorText: string, response?: Response | Headers): number | undefined {
|
||||
const normalizeDelay = (ms: number): number | undefined => (ms > 0 ? Math.ceil(ms + 1000) : undefined);
|
||||
|
||||
const headers = response instanceof Headers ? response : response?.headers;
|
||||
if (headers) {
|
||||
const retryAfter = headers.get("retry-after");
|
||||
if (retryAfter) {
|
||||
const retryAfterSeconds = Number(retryAfter);
|
||||
if (Number.isFinite(retryAfterSeconds)) {
|
||||
const delay = normalizeDelay(retryAfterSeconds * 1000);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
const retryAfterDate = new Date(retryAfter);
|
||||
const retryAfterMs = retryAfterDate.getTime();
|
||||
if (!Number.isNaN(retryAfterMs)) {
|
||||
const delay = normalizeDelay(retryAfterMs - Date.now());
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const rateLimitReset = headers.get("x-ratelimit-reset");
|
||||
if (rateLimitReset) {
|
||||
const resetSeconds = Number.parseInt(rateLimitReset, 10);
|
||||
if (!Number.isNaN(resetSeconds)) {
|
||||
const delay = normalizeDelay(resetSeconds * 1000 - Date.now());
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const rateLimitResetAfter = headers.get("x-ratelimit-reset-after");
|
||||
if (rateLimitResetAfter) {
|
||||
const resetAfterSeconds = Number(rateLimitResetAfter);
|
||||
if (Number.isFinite(resetAfterSeconds)) {
|
||||
const delay = normalizeDelay(resetAfterSeconds * 1000);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 1: "Your quota will reset after ..." (formats: "18h31m10s", "10m15s", "6s", "39s")
|
||||
const durationMatch = errorText.match(/reset after (?:(\d+)h)?(?:(\d+)m)?(\d+(?:\.\d+)?)s/i);
|
||||
if (durationMatch) {
|
||||
const hours = durationMatch[1] ? parseInt(durationMatch[1], 10) : 0;
|
||||
const minutes = durationMatch[2] ? parseInt(durationMatch[2], 10) : 0;
|
||||
const seconds = parseFloat(durationMatch[3]);
|
||||
if (!Number.isNaN(seconds)) {
|
||||
const totalMs = ((hours * 60 + minutes) * 60 + seconds) * 1000;
|
||||
const delay = normalizeDelay(totalMs);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 2: "Please retry in X[ms|s]"
|
||||
const retryInMatch = errorText.match(/Please retry in ([0-9.]+)(ms|s)/i);
|
||||
if (retryInMatch?.[1]) {
|
||||
const value = parseFloat(retryInMatch[1]);
|
||||
if (!Number.isNaN(value) && value > 0) {
|
||||
const ms = retryInMatch[2].toLowerCase() === "ms" ? value : value * 1000;
|
||||
const delay = normalizeDelay(ms);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 3: "retryDelay": "34.074824224s" (JSON field in error details)
|
||||
const retryDelayMatch = errorText.match(/"retryDelay":\s*"([0-9.]+)(ms|s)"/i);
|
||||
if (retryDelayMatch?.[1]) {
|
||||
const value = parseFloat(retryDelayMatch[1]);
|
||||
if (!Number.isNaN(value) && value > 0) {
|
||||
const ms = retryDelayMatch[2].toLowerCase() === "ms" ? value : value * 1000;
|
||||
const delay = normalizeDelay(ms);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function needsClaudeThinkingBetaHeader(model: Model<"google-gemini-cli">): boolean {
|
||||
return model.provider === "google-antigravity" && model.id.startsWith("claude-") && model.reasoning;
|
||||
}
|
||||
|
||||
function isGemini3ProModel(modelId: string): boolean {
|
||||
return /gemini-3(?:\.1)?-pro/.test(modelId.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(modelId: string): boolean {
|
||||
return /gemini-3(?:\.1)?-flash/.test(modelId.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3Model(modelId: string): boolean {
|
||||
return isGemini3ProModel(modelId) || isGemini3FlashModel(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if an error is retryable (rate limit, server error, network error, etc.)
|
||||
*/
|
||||
function isRetryableError(status: number, errorText: string): boolean {
|
||||
if (status === 429 || status === 500 || status === 502 || status === 503 || status === 504) {
|
||||
return true;
|
||||
}
|
||||
return /resource.?exhausted|rate.?limit|overloaded|service.?unavailable|other.?side.?closed/i.test(errorText);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a clean, user-friendly error message from Google API error response.
|
||||
* Parses JSON error responses and returns just the message field.
|
||||
*/
|
||||
function extractErrorMessage(errorText: string): string {
|
||||
try {
|
||||
const parsed = JSON.parse(errorText) as { error?: { message?: string } };
|
||||
if (parsed.error?.message) {
|
||||
return parsed.error.message;
|
||||
}
|
||||
} catch {
|
||||
// Not JSON, return as-is
|
||||
}
|
||||
return errorText;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sleep for a given number of milliseconds, respecting abort signal.
|
||||
*/
|
||||
function sleep(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Request was aborted"));
|
||||
return;
|
||||
}
|
||||
const timeout = setTimeout(resolve, ms);
|
||||
signal?.addEventListener("abort", () => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Request was aborted"));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
interface CloudCodeAssistRequest {
|
||||
project: string;
|
||||
model: string;
|
||||
request: {
|
||||
contents: Content[];
|
||||
sessionId?: string;
|
||||
systemInstruction?: { role?: string; parts: { text: string }[] };
|
||||
generationConfig?: {
|
||||
maxOutputTokens?: number;
|
||||
temperature?: number;
|
||||
thinkingConfig?: ThinkingConfig;
|
||||
};
|
||||
tools?: ReturnType<typeof convertTools>;
|
||||
toolConfig?: {
|
||||
functionCallingConfig: {
|
||||
mode: ReturnType<typeof mapToolChoice>;
|
||||
};
|
||||
};
|
||||
};
|
||||
requestType?: string;
|
||||
userAgent?: string;
|
||||
requestId?: string;
|
||||
}
|
||||
|
||||
interface CloudCodeAssistResponseChunk {
|
||||
response?: {
|
||||
candidates?: Array<{
|
||||
content?: {
|
||||
role: string;
|
||||
parts?: Array<{
|
||||
text?: string;
|
||||
thought?: boolean;
|
||||
thoughtSignature?: string;
|
||||
functionCall?: {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
id?: string;
|
||||
};
|
||||
}>;
|
||||
};
|
||||
finishReason?: string;
|
||||
}>;
|
||||
usageMetadata?: {
|
||||
promptTokenCount?: number;
|
||||
candidatesTokenCount?: number;
|
||||
thoughtsTokenCount?: number;
|
||||
totalTokenCount?: number;
|
||||
cachedContentTokenCount?: number;
|
||||
};
|
||||
modelVersion?: string;
|
||||
responseId?: string;
|
||||
};
|
||||
traceId?: string;
|
||||
}
|
||||
|
||||
export const streamGoogleGeminiCli: StreamFunction<"google-gemini-cli", GoogleGeminiCliOptions> = (
|
||||
model: Model<"google-gemini-cli">,
|
||||
context: Context,
|
||||
options?: GoogleGeminiCliOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "google-gemini-cli" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
// apiKey is JSON-encoded: { token, projectId }
|
||||
const apiKeyRaw = options?.apiKey;
|
||||
if (!apiKeyRaw) {
|
||||
throw new Error("Google Cloud Code Assist requires OAuth authentication. Use /login to authenticate.");
|
||||
}
|
||||
|
||||
let accessToken: string;
|
||||
let projectId: string;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(apiKeyRaw) as { token: string; projectId: string };
|
||||
accessToken = parsed.token;
|
||||
projectId = parsed.projectId;
|
||||
} catch {
|
||||
throw new Error("Invalid Google Cloud Code Assist credentials. Use /login to re-authenticate.");
|
||||
}
|
||||
|
||||
if (!accessToken || !projectId) {
|
||||
throw new Error("Missing token or projectId in Google Cloud credentials. Use /login to re-authenticate.");
|
||||
}
|
||||
|
||||
const isAntigravity = model.provider === "google-antigravity";
|
||||
const baseUrl = model.baseUrl?.trim();
|
||||
const endpoints = baseUrl ? [baseUrl] : isAntigravity ? ANTIGRAVITY_ENDPOINT_FALLBACKS : [DEFAULT_ENDPOINT];
|
||||
|
||||
let requestBody = buildRequest(model, context, projectId, options, isAntigravity);
|
||||
const nextRequestBody = await options?.onPayload?.(requestBody, model);
|
||||
if (nextRequestBody !== undefined) {
|
||||
requestBody = nextRequestBody as CloudCodeAssistRequest;
|
||||
}
|
||||
const headers = isAntigravity ? getAntigravityHeaders() : GEMINI_CLI_HEADERS;
|
||||
|
||||
const requestHeaders = {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
...headers,
|
||||
...(needsClaudeThinkingBetaHeader(model) ? { "anthropic-beta": CLAUDE_THINKING_BETA_HEADER } : {}),
|
||||
...options?.headers,
|
||||
};
|
||||
const requestBodyJson = JSON.stringify(requestBody);
|
||||
|
||||
// Fetch with retry logic for rate limits, transient errors, and endpoint fallbacks.
|
||||
// On 403/404, immediately try the next endpoint (no delay).
|
||||
// On 429/5xx, retry with backoff on the same or next endpoint.
|
||||
let response: Response | undefined;
|
||||
let lastError: Error | undefined;
|
||||
let requestUrl: string | undefined;
|
||||
let endpointIndex = 0;
|
||||
|
||||
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
try {
|
||||
const endpoint = endpoints[endpointIndex];
|
||||
requestUrl = `${endpoint}/v1internal:streamGenerateContent?alt=sse`;
|
||||
response = await fetch(requestUrl, {
|
||||
method: "POST",
|
||||
headers: requestHeaders,
|
||||
body: requestBodyJson,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
break; // Success, exit retry loop
|
||||
}
|
||||
|
||||
const errorText = await response.text();
|
||||
|
||||
// On 403/404, cascade to the next endpoint immediately (no delay)
|
||||
if ((response.status === 403 || response.status === 404) && endpointIndex < endpoints.length - 1) {
|
||||
endpointIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if retryable (429, 5xx, network patterns)
|
||||
if (attempt < MAX_RETRIES && isRetryableError(response.status, errorText)) {
|
||||
// Advance endpoint if possible
|
||||
if (endpointIndex < endpoints.length - 1) {
|
||||
endpointIndex++;
|
||||
}
|
||||
|
||||
// Use server-provided delay or exponential backoff
|
||||
const serverDelay = extractRetryDelay(errorText, response);
|
||||
const delayMs = serverDelay ?? BASE_DELAY_MS * 2 ** attempt;
|
||||
|
||||
// Check if server delay exceeds max allowed (default: 60s)
|
||||
const maxDelayMs = options?.maxRetryDelayMs ?? 60000;
|
||||
if (maxDelayMs > 0 && serverDelay && serverDelay > maxDelayMs) {
|
||||
const delaySeconds = Math.ceil(serverDelay / 1000);
|
||||
throw new Error(
|
||||
`Server requested ${delaySeconds}s retry delay (max: ${Math.ceil(maxDelayMs / 1000)}s). ${extractErrorMessage(errorText)}`,
|
||||
);
|
||||
}
|
||||
|
||||
await sleep(delayMs, options?.signal);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Not retryable or max retries exceeded
|
||||
throw new Error(`Cloud Code Assist API error (${response.status}): ${extractErrorMessage(errorText)}`);
|
||||
} catch (error) {
|
||||
// Check for abort - fetch throws AbortError, our code throws "Request was aborted"
|
||||
if (error instanceof Error) {
|
||||
if (error.name === "AbortError" || error.message === "Request was aborted") {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
}
|
||||
// Extract detailed error message from fetch errors (Node includes cause)
|
||||
lastError = error instanceof Error ? error : new Error(String(error));
|
||||
if (lastError.message === "fetch failed" && lastError.cause instanceof Error) {
|
||||
lastError = new Error(`Network error: ${lastError.cause.message}`);
|
||||
}
|
||||
// Network errors are retryable
|
||||
if (attempt < MAX_RETRIES) {
|
||||
const delayMs = BASE_DELAY_MS * 2 ** attempt;
|
||||
await sleep(delayMs, options?.signal);
|
||||
continue;
|
||||
}
|
||||
throw lastError;
|
||||
}
|
||||
}
|
||||
|
||||
if (!response || !response.ok) {
|
||||
throw lastError ?? new Error("Failed to get response after retries");
|
||||
}
|
||||
|
||||
let started = false;
|
||||
const ensureStarted = () => {
|
||||
if (!started) {
|
||||
stream.push({ type: "start", partial: output });
|
||||
started = true;
|
||||
}
|
||||
};
|
||||
|
||||
const resetOutput = () => {
|
||||
output.content = [];
|
||||
output.usage = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
output.stopReason = "stop";
|
||||
output.errorMessage = undefined;
|
||||
output.timestamp = Date.now();
|
||||
started = false;
|
||||
};
|
||||
|
||||
const streamResponse = async (activeResponse: Response): Promise<boolean> => {
|
||||
if (!activeResponse.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
let hasContent = false;
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
|
||||
// Read SSE stream
|
||||
const reader = activeResponse.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
// Set up abort handler to cancel reader when signal fires
|
||||
const abortHandler = () => {
|
||||
void reader.cancel().catch(() => {});
|
||||
};
|
||||
options?.signal?.addEventListener("abort", abortHandler);
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
// Check abort signal before each read
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith("data:")) continue;
|
||||
|
||||
const jsonStr = line.slice(5).trim();
|
||||
if (!jsonStr) continue;
|
||||
|
||||
let chunk: CloudCodeAssistResponseChunk;
|
||||
try {
|
||||
chunk = JSON.parse(jsonStr);
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Unwrap the response
|
||||
const responseData = chunk.response;
|
||||
if (!responseData) continue;
|
||||
// Cloud Code Assist mirrors Gemini's responseId field. Keep the first non-empty one.
|
||||
// A single streamed response should retain the same ID across chunks.
|
||||
output.responseId ||= responseData.responseId;
|
||||
|
||||
const candidate = responseData.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
hasContent = true;
|
||||
const isThinking = isThinkingPart(part);
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
(!isThinking && currentBlock.type !== "text")
|
||||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blocks.length - 1,
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (isThinking) {
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||
output.content.push(currentBlock);
|
||||
ensureStarted();
|
||||
stream.push({
|
||||
type: "thinking_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
ensureStarted();
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
}
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = retainThoughtSignature(
|
||||
currentBlock.thinkingSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
currentBlock.textSignature = retainThoughtSignature(
|
||||
currentBlock.textSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (part.functionCall) {
|
||||
hasContent = true;
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
const providedId = part.functionCall.id;
|
||||
const needsNewId =
|
||||
!providedId ||
|
||||
output.content.some((b) => b.type === "toolCall" && b.id === providedId);
|
||||
const toolCallId = needsNewId
|
||||
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
||||
: providedId;
|
||||
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: toolCallId,
|
||||
name: part.functionCall.name || "",
|
||||
arguments: (part.functionCall.args as Record<string, unknown>) ?? {},
|
||||
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
|
||||
};
|
||||
|
||||
output.content.push(toolCall);
|
||||
ensureStarted();
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: JSON.stringify(toolCall.arguments),
|
||||
partial: output,
|
||||
});
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: blockIndex(),
|
||||
toolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = mapStopReasonString(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
if (responseData.usageMetadata) {
|
||||
// promptTokenCount includes cachedContentTokenCount, so subtract to get fresh input
|
||||
const promptTokens = responseData.usageMetadata.promptTokenCount || 0;
|
||||
const cacheReadTokens = responseData.usageMetadata.cachedContentTokenCount || 0;
|
||||
output.usage = {
|
||||
input: promptTokens - cacheReadTokens,
|
||||
output:
|
||||
(responseData.usageMetadata.candidatesTokenCount || 0) +
|
||||
(responseData.usageMetadata.thoughtsTokenCount || 0),
|
||||
cacheRead: cacheReadTokens,
|
||||
cacheWrite: 0,
|
||||
totalTokens: responseData.usageMetadata.totalTokenCount || 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
options?.signal?.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return hasContent;
|
||||
};
|
||||
|
||||
let receivedContent = false;
|
||||
let currentResponse = response;
|
||||
|
||||
for (let emptyAttempt = 0; emptyAttempt <= MAX_EMPTY_STREAM_RETRIES; emptyAttempt++) {
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (emptyAttempt > 0) {
|
||||
const backoffMs = EMPTY_STREAM_BASE_DELAY_MS * 2 ** (emptyAttempt - 1);
|
||||
await sleep(backoffMs, options?.signal);
|
||||
|
||||
if (!requestUrl) {
|
||||
throw new Error("Missing request URL");
|
||||
}
|
||||
|
||||
currentResponse = await fetch(requestUrl, {
|
||||
method: "POST",
|
||||
headers: requestHeaders,
|
||||
body: requestBodyJson,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (!currentResponse.ok) {
|
||||
const retryErrorText = await currentResponse.text();
|
||||
throw new Error(`Cloud Code Assist API error (${currentResponse.status}): ${retryErrorText}`);
|
||||
}
|
||||
}
|
||||
|
||||
const streamed = await streamResponse(currentResponse);
|
||||
if (streamed) {
|
||||
receivedContent = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (emptyAttempt < MAX_EMPTY_STREAM_RETRIES) {
|
||||
resetOutput();
|
||||
}
|
||||
}
|
||||
|
||||
if (!receivedContent) {
|
||||
throw new Error("Cloud Code Assist API returned an empty response");
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) {
|
||||
if ("index" in block) {
|
||||
delete (block as { index?: number }).index;
|
||||
}
|
||||
}
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleGoogleGeminiCli: StreamFunction<"google-gemini-cli", SimpleStreamOptions> = (
|
||||
model: Model<"google-gemini-cli">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey;
|
||||
if (!apiKey) {
|
||||
throw new Error("Google Cloud Code Assist requires OAuth authentication. Use /login to authenticate.");
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
thinking: { enabled: false },
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(options.reasoning)!;
|
||||
if (isGemini3Model(model.id)) {
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGeminiCliThinkingLevel(effort, model.id),
|
||||
},
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
}
|
||||
|
||||
const defaultBudgets: ThinkingBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
};
|
||||
const budgets = { ...defaultBudgets, ...options.thinkingBudgets };
|
||||
|
||||
const minOutputTokens = 1024;
|
||||
let thinkingBudget = budgets[effort]!;
|
||||
const maxTokens = Math.min((base.maxTokens || 0) + thinkingBudget, model.maxTokens);
|
||||
|
||||
if (maxTokens <= thinkingBudget) {
|
||||
thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
|
||||
}
|
||||
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
maxTokens,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: thinkingBudget,
|
||||
},
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
};
|
||||
|
||||
export function buildRequest(
|
||||
model: Model<"google-gemini-cli">,
|
||||
context: Context,
|
||||
projectId: string,
|
||||
options: GoogleGeminiCliOptions = {},
|
||||
isAntigravity = false,
|
||||
): CloudCodeAssistRequest {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const generationConfig: CloudCodeAssistRequest["request"]["generationConfig"] = {};
|
||||
if (options.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
// Thinking config
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
generationConfig.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
};
|
||||
// Gemini 3 models use thinkingLevel, older models use thinkingBudget
|
||||
if (options.thinking.level !== undefined) {
|
||||
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
|
||||
generationConfig.thinkingConfig.thinkingLevel = options.thinking.level as any;
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
generationConfig.thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
} else if (model.reasoning && options.thinking && !options.thinking.enabled) {
|
||||
generationConfig.thinkingConfig = getDisabledThinkingConfig(model.id);
|
||||
}
|
||||
|
||||
const request: CloudCodeAssistRequest["request"] = {
|
||||
contents,
|
||||
};
|
||||
|
||||
request.sessionId = options.sessionId;
|
||||
|
||||
// System instruction must be object with parts, not plain string
|
||||
if (context.systemPrompt) {
|
||||
request.systemInstruction = {
|
||||
parts: [{ text: sanitizeSurrogates(context.systemPrompt) }],
|
||||
};
|
||||
}
|
||||
|
||||
if (Object.keys(generationConfig).length > 0) {
|
||||
request.generationConfig = generationConfig;
|
||||
}
|
||||
|
||||
if (context.tools && context.tools.length > 0) {
|
||||
// Claude models on Cloud Code Assist need the legacy `parameters` field;
|
||||
// the API translates it into Anthropic's `input_schema`.
|
||||
const useParameters = model.id.startsWith("claude-");
|
||||
request.tools = convertTools(context.tools, useParameters);
|
||||
if (options.toolChoice) {
|
||||
request.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (isAntigravity) {
|
||||
const existingParts = request.systemInstruction?.parts ?? [];
|
||||
request.systemInstruction = {
|
||||
role: "user",
|
||||
parts: [
|
||||
{ text: ANTIGRAVITY_SYSTEM_INSTRUCTION },
|
||||
{ text: `Please ignore following [ignore]${ANTIGRAVITY_SYSTEM_INSTRUCTION}[/ignore]` },
|
||||
...existingParts,
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
project: projectId,
|
||||
model: model.id,
|
||||
request,
|
||||
...(isAntigravity ? { requestType: "agent" } : {}),
|
||||
userAgent: isAntigravity ? "antigravity" : "pi-coding-agent",
|
||||
requestId: `${isAntigravity ? "agent" : "pi"}-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`,
|
||||
};
|
||||
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
|
||||
|
||||
function getDisabledThinkingConfig(modelId: string): ThinkingConfig {
|
||||
// Google docs: Gemini 3.1 Pro cannot disable thinking, and Gemini 3 Flash / Flash-Lite
|
||||
// do not support full thinking-off either. For Gemini 3 models, use the lowest supported
|
||||
// thinkingLevel without includeThoughts so hidden thinking remains invisible to pi.
|
||||
if (isGemini3ProModel(modelId)) {
|
||||
return { thinkingLevel: "LOW" as any };
|
||||
}
|
||||
if (isGemini3FlashModel(modelId)) {
|
||||
return { thinkingLevel: "MINIMAL" as any };
|
||||
}
|
||||
|
||||
// Gemini 2.x supports disabling via thinkingBudget = 0.
|
||||
return { thinkingBudget: 0 };
|
||||
}
|
||||
|
||||
function getGeminiCliThinkingLevel(effort: ClampedThinkingLevel, modelId: string): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(modelId)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
326
packages/ai/src/providers/google-shared.ts
Normal file
326
packages/ai/src/providers/google-shared.ts
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
/**
|
||||
* Shared utilities for Google Generative AI and Google Cloud Code Assist providers.
|
||||
*/
|
||||
|
||||
import { type Content, FinishReason, FunctionCallingConfigMode, type Part } from "@google/genai";
|
||||
import type { Context, ImageContent, Model, StopReason, TextContent, Tool } from "../types.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
type GoogleApiType = "google-generative-ai" | "google-gemini-cli" | "google-vertex";
|
||||
|
||||
/**
|
||||
* Determines whether a streamed Gemini `Part` should be treated as "thinking".
|
||||
*
|
||||
* Protocol note (Gemini / Vertex AI thought signatures):
|
||||
* - `thought: true` is the definitive marker for thinking content (thought summaries).
|
||||
* - `thoughtSignature` is an encrypted representation of the model's internal thought process
|
||||
* used to preserve reasoning context across multi-turn interactions.
|
||||
* - `thoughtSignature` can appear on ANY part type (text, functionCall, etc.) - it does NOT
|
||||
* indicate the part itself is thinking content.
|
||||
* - For non-functionCall responses, the signature appears on the last part for context replay.
|
||||
* - When persisting/replaying model outputs, signature-bearing parts must be preserved as-is;
|
||||
* do not merge/move signatures across parts.
|
||||
*
|
||||
* See: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
*/
|
||||
export function isThinkingPart(part: Pick<Part, "thought" | "thoughtSignature">): boolean {
|
||||
return part.thought === true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retain thought signatures during streaming.
|
||||
*
|
||||
* Some backends only send `thoughtSignature` on the first delta for a given part/block; later deltas may omit it.
|
||||
* This helper preserves the last non-empty signature for the current block.
|
||||
*
|
||||
* Note: this does NOT merge or move signatures across distinct response parts. It only prevents
|
||||
* a signature from being overwritten with `undefined` within the same streamed block.
|
||||
*/
|
||||
export function retainThoughtSignature(existing: string | undefined, incoming: string | undefined): string | undefined {
|
||||
if (typeof incoming === "string" && incoming.length > 0) return incoming;
|
||||
return existing;
|
||||
}
|
||||
|
||||
// Thought signatures must be base64 for Google APIs (TYPE_BYTES).
|
||||
const base64SignaturePattern = /^[A-Za-z0-9+/]+={0,2}$/;
|
||||
|
||||
// Sentinel value that tells the Gemini API to skip thought signature validation.
|
||||
// Used for unsigned function call parts (e.g. replayed from providers without thought signatures).
|
||||
// See: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
const SKIP_THOUGHT_SIGNATURE = "skip_thought_signature_validator";
|
||||
|
||||
function isValidThoughtSignature(signature: string | undefined): boolean {
|
||||
if (!signature) return false;
|
||||
if (signature.length % 4 !== 0) return false;
|
||||
return base64SignaturePattern.test(signature);
|
||||
}
|
||||
|
||||
/**
|
||||
* Only keep signatures from the same provider/model and with valid base64.
|
||||
*/
|
||||
function resolveThoughtSignature(isSameProviderAndModel: boolean, signature: string | undefined): string | undefined {
|
||||
return isSameProviderAndModel && isValidThoughtSignature(signature) ? signature : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Models via Google APIs that require explicit tool call IDs in function calls/responses.
|
||||
*/
|
||||
export function requiresToolCallId(modelId: string): boolean {
|
||||
return modelId.startsWith("claude-") || modelId.startsWith("gpt-oss-");
|
||||
}
|
||||
|
||||
function getGeminiMajorVersion(modelId: string): number | undefined {
|
||||
const match = modelId.toLowerCase().match(/^gemini(?:-live)?-(\d+)/);
|
||||
if (!match) return undefined;
|
||||
return Number.parseInt(match[1], 10);
|
||||
}
|
||||
|
||||
function supportsMultimodalFunctionResponse(modelId: string): boolean {
|
||||
const geminiMajorVersion = getGeminiMajorVersion(modelId);
|
||||
if (geminiMajorVersion !== undefined) {
|
||||
return geminiMajorVersion >= 3;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert internal messages to Gemini Content[] format.
|
||||
*/
|
||||
export function convertMessages<T extends GoogleApiType>(model: Model<T>, context: Context): Content[] {
|
||||
const contents: Content[] = [];
|
||||
const normalizeToolCallId = (id: string): string => {
|
||||
if (!requiresToolCallId(model.id)) return id;
|
||||
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
|
||||
};
|
||||
|
||||
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [{ text: sanitizeSurrogates(msg.content) }],
|
||||
});
|
||||
} else {
|
||||
const parts: Part[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return { text: sanitizeSurrogates(item.text) };
|
||||
} else {
|
||||
return {
|
||||
inlineData: {
|
||||
mimeType: item.mimeType,
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredParts = !model.input.includes("image") ? parts.filter((p) => p.text !== undefined) : parts;
|
||||
if (filteredParts.length === 0) continue;
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: filteredParts,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const parts: Part[] = [];
|
||||
// Check if message is from same provider and model - only then keep thinking blocks
|
||||
const isSameProviderAndModel = msg.provider === model.provider && msg.model === model.id;
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
// Skip empty text blocks - they can cause issues with some models (e.g. Claude via Antigravity)
|
||||
if (!block.text || block.text.trim() === "") continue;
|
||||
const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.textSignature);
|
||||
parts.push({
|
||||
text: sanitizeSurrogates(block.text),
|
||||
...(thoughtSignature && { thoughtSignature }),
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
// Skip empty thinking blocks
|
||||
if (!block.thinking || block.thinking.trim() === "") continue;
|
||||
// Only keep as thinking block if same provider AND same model
|
||||
// Otherwise convert to plain text (no tags to avoid model mimicking them)
|
||||
if (isSameProviderAndModel) {
|
||||
const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.thinkingSignature);
|
||||
parts.push({
|
||||
thought: true,
|
||||
text: sanitizeSurrogates(block.thinking),
|
||||
...(thoughtSignature && { thoughtSignature }),
|
||||
});
|
||||
} else {
|
||||
parts.push({
|
||||
text: sanitizeSurrogates(block.thinking),
|
||||
});
|
||||
}
|
||||
} else if (block.type === "toolCall") {
|
||||
const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.thoughtSignature);
|
||||
// Gemini 3 requires thoughtSignature on all function calls when thinking mode is enabled.
|
||||
// Use the skip_thought_signature_validator sentinel for unsigned function calls
|
||||
// (e.g. replayed from providers without thought signatures like Claude via Antigravity).
|
||||
const isGemini3 = model.id.toLowerCase().includes("gemini-3");
|
||||
const effectiveSignature = thoughtSignature || (isGemini3 ? SKIP_THOUGHT_SIGNATURE : undefined);
|
||||
const part: Part = {
|
||||
functionCall: {
|
||||
name: block.name,
|
||||
args: block.arguments ?? {},
|
||||
...(requiresToolCallId(model.id) ? { id: block.id } : {}),
|
||||
},
|
||||
...(effectiveSignature && { thoughtSignature: effectiveSignature }),
|
||||
};
|
||||
parts.push(part);
|
||||
}
|
||||
}
|
||||
|
||||
if (parts.length === 0) continue;
|
||||
contents.push({
|
||||
role: "model",
|
||||
parts,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
// Extract text and image content
|
||||
const textContent = msg.content.filter((c): c is TextContent => c.type === "text");
|
||||
const textResult = textContent.map((c) => c.text).join("\n");
|
||||
const imageContent = model.input.includes("image")
|
||||
? msg.content.filter((c): c is ImageContent => c.type === "image")
|
||||
: [];
|
||||
|
||||
const hasText = textResult.length > 0;
|
||||
const hasImages = imageContent.length > 0;
|
||||
|
||||
// Gemini 3+ models support multimodal function responses with images nested inside
|
||||
// functionResponse.parts. Claude and other non-Gemini models behind Cloud Code Assist /
|
||||
// Antigravity also accept this shape. Gemini < 3 still needs a separate user image turn.
|
||||
const modelSupportsMultimodalFunctionResponse = supportsMultimodalFunctionResponse(model.id);
|
||||
|
||||
// Use "output" key for success, "error" key for errors as per SDK documentation
|
||||
const responseValue = hasText ? sanitizeSurrogates(textResult) : hasImages ? "(see attached image)" : "";
|
||||
|
||||
const imageParts: Part[] = imageContent.map((imageBlock) => ({
|
||||
inlineData: {
|
||||
mimeType: imageBlock.mimeType,
|
||||
data: imageBlock.data,
|
||||
},
|
||||
}));
|
||||
|
||||
const includeId = requiresToolCallId(model.id);
|
||||
const functionResponsePart: Part = {
|
||||
functionResponse: {
|
||||
name: msg.toolName,
|
||||
response: msg.isError ? { error: responseValue } : { output: responseValue },
|
||||
...(hasImages && modelSupportsMultimodalFunctionResponse && { parts: imageParts }),
|
||||
...(includeId ? { id: msg.toolCallId } : {}),
|
||||
},
|
||||
};
|
||||
|
||||
// Cloud Code Assist API requires all function responses to be in a single user turn.
|
||||
// Check if the last content is already a user turn with function responses and merge.
|
||||
const lastContent = contents[contents.length - 1];
|
||||
if (lastContent?.role === "user" && lastContent.parts?.some((p) => p.functionResponse)) {
|
||||
lastContent.parts.push(functionResponsePart);
|
||||
} else {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [functionResponsePart],
|
||||
});
|
||||
}
|
||||
|
||||
// For Gemini < 3, add images in a separate user message
|
||||
if (hasImages && !modelSupportsMultimodalFunctionResponse) {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [{ text: "Tool result image:" }, ...imageParts],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return contents;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert tools to Gemini function declarations format.
|
||||
*
|
||||
* By default uses `parametersJsonSchema` which supports full JSON Schema (including
|
||||
* anyOf, oneOf, const, etc.). Set `useParameters` to true to use the legacy `parameters`
|
||||
* field instead (OpenAPI 3.03 Schema). This is needed for Cloud Code Assist with Claude
|
||||
* models, where the API translates `parameters` into Anthropic's `input_schema`.
|
||||
*/
|
||||
export function convertTools(
|
||||
tools: Tool[],
|
||||
useParameters = false,
|
||||
): { functionDeclarations: Record<string, unknown>[] }[] | undefined {
|
||||
if (tools.length === 0) return undefined;
|
||||
return [
|
||||
{
|
||||
functionDeclarations: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
...(useParameters ? { parameters: tool.parameters } : { parametersJsonSchema: tool.parameters }),
|
||||
})),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Map tool choice string to Gemini FunctionCallingConfigMode.
|
||||
*/
|
||||
export function mapToolChoice(choice: string): FunctionCallingConfigMode {
|
||||
switch (choice) {
|
||||
case "auto":
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
case "none":
|
||||
return FunctionCallingConfigMode.NONE;
|
||||
case "any":
|
||||
return FunctionCallingConfigMode.ANY;
|
||||
default:
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map Gemini FinishReason to our StopReason.
|
||||
*/
|
||||
export function mapStopReason(reason: FinishReason): StopReason {
|
||||
switch (reason) {
|
||||
case FinishReason.STOP:
|
||||
return "stop";
|
||||
case FinishReason.MAX_TOKENS:
|
||||
return "length";
|
||||
case FinishReason.BLOCKLIST:
|
||||
case FinishReason.PROHIBITED_CONTENT:
|
||||
case FinishReason.SPII:
|
||||
case FinishReason.SAFETY:
|
||||
case FinishReason.IMAGE_SAFETY:
|
||||
case FinishReason.IMAGE_PROHIBITED_CONTENT:
|
||||
case FinishReason.IMAGE_RECITATION:
|
||||
case FinishReason.IMAGE_OTHER:
|
||||
case FinishReason.RECITATION:
|
||||
case FinishReason.FINISH_REASON_UNSPECIFIED:
|
||||
case FinishReason.OTHER:
|
||||
case FinishReason.LANGUAGE:
|
||||
case FinishReason.MALFORMED_FUNCTION_CALL:
|
||||
case FinishReason.UNEXPECTED_TOOL_CALL:
|
||||
case FinishReason.NO_IMAGE:
|
||||
return "error";
|
||||
default: {
|
||||
const _exhaustive: never = reason;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map string finish reason to our StopReason (for raw API responses).
|
||||
*/
|
||||
export function mapStopReasonString(reason: string): StopReason {
|
||||
switch (reason) {
|
||||
case "STOP":
|
||||
return "stop";
|
||||
case "MAX_TOKENS":
|
||||
return "length";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
541
packages/ai/src/providers/google-vertex.ts
Normal file
541
packages/ai/src/providers/google-vertex.ts
Normal file
|
|
@ -0,0 +1,541 @@
|
|||
import {
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type ThinkingConfig,
|
||||
ThinkingLevel,
|
||||
} from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
ThinkingLevel as PiThinkingLevel,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReason,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
export interface GoogleVertexOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number; // -1 for dynamic, 0 to disable
|
||||
level?: GoogleThinkingLevel;
|
||||
};
|
||||
project?: string;
|
||||
location?: string;
|
||||
}
|
||||
|
||||
const API_VERSION = "v1";
|
||||
|
||||
const THINKING_LEVEL_MAP: Record<GoogleThinkingLevel, ThinkingLevel> = {
|
||||
THINKING_LEVEL_UNSPECIFIED: ThinkingLevel.THINKING_LEVEL_UNSPECIFIED,
|
||||
MINIMAL: ThinkingLevel.MINIMAL,
|
||||
LOW: ThinkingLevel.LOW,
|
||||
MEDIUM: ThinkingLevel.MEDIUM,
|
||||
HIGH: ThinkingLevel.HIGH,
|
||||
};
|
||||
|
||||
// Counter for generating unique tool call IDs
|
||||
let toolCallCounter = 0;
|
||||
|
||||
export const streamGoogleVertex: StreamFunction<"google-vertex", GoogleVertexOptions> = (
|
||||
model: Model<"google-vertex">,
|
||||
context: Context,
|
||||
options?: GoogleVertexOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "google-vertex" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = resolveApiKey(options);
|
||||
// Create the client using either a Vertex API key, if provided, or ADC with project and location
|
||||
const client = apiKey
|
||||
? createClientWithApiKey(model, apiKey, options?.headers)
|
||||
: createClient(model, resolveProject(options), resolveLocation(options), options?.headers);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as GenerateContentParameters;
|
||||
}
|
||||
const googleStream = await client.models.generateContentStream(params);
|
||||
|
||||
stream.push({ type: "start", partial: output });
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
for await (const chunk of googleStream) {
|
||||
// Vertex uses the same @google/genai GenerateContentResponse type as Gemini.
|
||||
// responseId is documented there as an output-only identifier for each response.
|
||||
output.responseId ||= chunk.responseId;
|
||||
const candidate = chunk.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
const isThinking = isThinkingPart(part);
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
(!isThinking && currentBlock.type !== "text")
|
||||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blocks.length - 1,
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (isThinking) {
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
}
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = retainThoughtSignature(
|
||||
currentBlock.thinkingSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
currentBlock.textSignature = retainThoughtSignature(
|
||||
currentBlock.textSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (part.functionCall) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
const providedId = part.functionCall.id;
|
||||
const needsNewId =
|
||||
!providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
|
||||
const toolCallId = needsNewId
|
||||
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
||||
: providedId;
|
||||
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: toolCallId,
|
||||
name: part.functionCall.name || "",
|
||||
arguments: (part.functionCall.args as Record<string, any>) ?? {},
|
||||
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
|
||||
};
|
||||
|
||||
output.content.push(toolCall);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: JSON.stringify(toolCall.arguments),
|
||||
partial: output,
|
||||
});
|
||||
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = mapStopReason(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.usageMetadata) {
|
||||
output.usage = {
|
||||
input: chunk.usageMetadata.promptTokenCount || 0,
|
||||
output:
|
||||
(chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
|
||||
cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: chunk.usageMetadata.totalTokenCount || 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
// Remove internal index property used during streaming
|
||||
for (const block of output.content) {
|
||||
if ("index" in block) {
|
||||
delete (block as { index?: number }).index;
|
||||
}
|
||||
}
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleGoogleVertex: StreamFunction<"google-vertex", SimpleStreamOptions> = (
|
||||
model: Model<"google-vertex">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const base = buildBaseOptions(model, options, undefined);
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: { enabled: false },
|
||||
} satisfies GoogleVertexOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(options.reasoning)!;
|
||||
const geminiModel = model as unknown as Model<"google-generative-ai">;
|
||||
|
||||
if (isGemini3ProModel(geminiModel) || isGemini3FlashModel(geminiModel)) {
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGemini3ThinkingLevel(effort, geminiModel),
|
||||
},
|
||||
} satisfies GoogleVertexOptions);
|
||||
}
|
||||
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: getGoogleBudget(geminiModel, effort, options.thinkingBudgets),
|
||||
},
|
||||
} satisfies GoogleVertexOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"google-vertex">,
|
||||
project: string,
|
||||
location: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
): GoogleGenAI {
|
||||
const httpOptions: { headers?: Record<string, string> } = {};
|
||||
|
||||
if (model.headers || optionsHeaders) {
|
||||
httpOptions.headers = { ...model.headers, ...optionsHeaders };
|
||||
}
|
||||
|
||||
const hasHttpOptions = Object.values(httpOptions).some(Boolean);
|
||||
|
||||
return new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project,
|
||||
location,
|
||||
apiVersion: API_VERSION,
|
||||
httpOptions: hasHttpOptions ? httpOptions : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
function createClientWithApiKey(
|
||||
model: Model<"google-vertex">,
|
||||
apiKey: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
): GoogleGenAI {
|
||||
const httpOptions: { headers?: Record<string, string> } = {};
|
||||
|
||||
if (model.headers || optionsHeaders) {
|
||||
httpOptions.headers = { ...model.headers, ...optionsHeaders };
|
||||
}
|
||||
|
||||
const hasHttpOptions = Object.values(httpOptions).some(Boolean);
|
||||
|
||||
return new GoogleGenAI({
|
||||
vertexai: true,
|
||||
apiKey,
|
||||
apiVersion: API_VERSION,
|
||||
httpOptions: hasHttpOptions ? httpOptions : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
function resolveApiKey(options?: GoogleVertexOptions): string | undefined {
|
||||
const apiKey = options?.apiKey?.trim() || process.env.GOOGLE_CLOUD_API_KEY?.trim();
|
||||
if (!apiKey || isPlaceholderApiKey(apiKey)) {
|
||||
return undefined;
|
||||
}
|
||||
return apiKey;
|
||||
}
|
||||
|
||||
function isPlaceholderApiKey(apiKey: string): boolean {
|
||||
return /^<[^>]+>$/.test(apiKey);
|
||||
}
|
||||
|
||||
function resolveProject(options?: GoogleVertexOptions): string {
|
||||
const project = options?.project || process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT;
|
||||
if (!project) {
|
||||
throw new Error(
|
||||
"Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
|
||||
);
|
||||
}
|
||||
return project;
|
||||
}
|
||||
|
||||
function resolveLocation(options?: GoogleVertexOptions): string {
|
||||
const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
|
||||
if (!location) {
|
||||
throw new Error("Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.");
|
||||
}
|
||||
return location;
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"google-vertex">,
|
||||
context: Context,
|
||||
options: GoogleVertexOptions = {},
|
||||
): GenerateContentParameters {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const generationConfig: GenerateContentConfig = {};
|
||||
if (options.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
const config: GenerateContentConfig = {
|
||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||
...(context.systemPrompt && { systemInstruction: sanitizeSurrogates(context.systemPrompt) }),
|
||||
...(context.tools && context.tools.length > 0 && { tools: convertTools(context.tools) }),
|
||||
};
|
||||
|
||||
if (context.tools && context.tools.length > 0 && options.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
config.toolConfig = undefined;
|
||||
}
|
||||
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
|
||||
if (options.thinking.level !== undefined) {
|
||||
thinkingConfig.thinkingLevel = THINKING_LEVEL_MAP[options.thinking.level];
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
config.thinkingConfig = thinkingConfig;
|
||||
} else if (model.reasoning && options.thinking && !options.thinking.enabled) {
|
||||
config.thinkingConfig = getDisabledThinkingConfig(model);
|
||||
}
|
||||
|
||||
if (options.signal) {
|
||||
if (options.signal.aborted) {
|
||||
throw new Error("Request aborted");
|
||||
}
|
||||
config.abortSignal = options.signal;
|
||||
}
|
||||
|
||||
const params: GenerateContentParameters = {
|
||||
model: model.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<PiThinkingLevel, "xhigh">;
|
||||
|
||||
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function getDisabledThinkingConfig(model: Model<"google-vertex">): ThinkingConfig {
|
||||
// Google docs: Gemini 3.1 Pro cannot disable thinking, and Gemini 3 Flash / Flash-Lite
|
||||
// do not support full thinking-off either. For Gemini 3 models, use the lowest supported
|
||||
// thinkingLevel without includeThoughts so hidden thinking remains invisible to pi.
|
||||
const geminiModel = model as unknown as Model<"google-generative-ai">;
|
||||
if (isGemini3ProModel(geminiModel)) {
|
||||
return { thinkingLevel: ThinkingLevel.LOW };
|
||||
}
|
||||
if (isGemini3FlashModel(geminiModel)) {
|
||||
return { thinkingLevel: ThinkingLevel.MINIMAL };
|
||||
}
|
||||
|
||||
// Gemini 2.x supports disabling via thinkingBudget = 0.
|
||||
return { thinkingBudget: 0 };
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGoogleBudget(
|
||||
model: Model<"google-generative-ai">,
|
||||
effort: ClampedThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): number {
|
||||
if (customBudgets?.[effort] !== undefined) {
|
||||
return customBudgets[effort]!;
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
475
packages/ai/src/providers/google.ts
Normal file
475
packages/ai/src/providers/google.ts
Normal file
|
|
@ -0,0 +1,475 @@
|
|||
import {
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type ThinkingConfig,
|
||||
} from "@google/genai";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReason,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
export interface GoogleOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number; // -1 for dynamic, 0 to disable
|
||||
level?: GoogleThinkingLevel;
|
||||
};
|
||||
}
|
||||
|
||||
// Counter for generating unique tool call IDs
|
||||
let toolCallCounter = 0;
|
||||
|
||||
export const streamGoogle: StreamFunction<"google-generative-ai", GoogleOptions> = (
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options?: GoogleOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "google-generative-ai" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, apiKey, options?.headers);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as GenerateContentParameters;
|
||||
}
|
||||
const googleStream = await client.models.generateContentStream(params);
|
||||
|
||||
stream.push({ type: "start", partial: output });
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
for await (const chunk of googleStream) {
|
||||
// @google/genai documents GenerateContentResponse.responseId as an output-only field
|
||||
// used to identify each response. Keep the first non-empty one from the stream.
|
||||
output.responseId ||= chunk.responseId;
|
||||
const candidate = chunk.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
const isThinking = isThinkingPart(part);
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
(!isThinking && currentBlock.type !== "text")
|
||||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blocks.length - 1,
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (isThinking) {
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
}
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = retainThoughtSignature(
|
||||
currentBlock.thinkingSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
currentBlock.textSignature = retainThoughtSignature(
|
||||
currentBlock.textSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (part.functionCall) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
// Generate unique ID if not provided or if it's a duplicate
|
||||
const providedId = part.functionCall.id;
|
||||
const needsNewId =
|
||||
!providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
|
||||
const toolCallId = needsNewId
|
||||
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
||||
: providedId;
|
||||
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: toolCallId,
|
||||
name: part.functionCall.name || "",
|
||||
arguments: (part.functionCall.args as Record<string, any>) ?? {},
|
||||
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
|
||||
};
|
||||
|
||||
output.content.push(toolCall);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: JSON.stringify(toolCall.arguments),
|
||||
partial: output,
|
||||
});
|
||||
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = mapStopReason(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.usageMetadata) {
|
||||
output.usage = {
|
||||
input: chunk.usageMetadata.promptTokenCount || 0,
|
||||
output:
|
||||
(chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
|
||||
cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: chunk.usageMetadata.totalTokenCount || 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
// Remove internal index property used during streaming
|
||||
for (const block of output.content) {
|
||||
if ("index" in block) {
|
||||
delete (block as { index?: number }).index;
|
||||
}
|
||||
}
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleGoogle: StreamFunction<"google-generative-ai", SimpleStreamOptions> = (
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogle(model, context, { ...base, thinking: { enabled: false } } satisfies GoogleOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(options.reasoning)!;
|
||||
const googleModel = model as Model<"google-generative-ai">;
|
||||
|
||||
if (isGemini3ProModel(googleModel) || isGemini3FlashModel(googleModel)) {
|
||||
return streamGoogle(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGemini3ThinkingLevel(effort, googleModel),
|
||||
},
|
||||
} satisfies GoogleOptions);
|
||||
}
|
||||
|
||||
return streamGoogle(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: getGoogleBudget(googleModel, effort, options.thinkingBudgets),
|
||||
},
|
||||
} satisfies GoogleOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"google-generative-ai">,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
): GoogleGenAI {
|
||||
const httpOptions: { baseUrl?: string; apiVersion?: string; headers?: Record<string, string> } = {};
|
||||
if (model.baseUrl) {
|
||||
httpOptions.baseUrl = model.baseUrl;
|
||||
httpOptions.apiVersion = ""; // baseUrl already includes version path, don't append
|
||||
}
|
||||
if (model.headers || optionsHeaders) {
|
||||
httpOptions.headers = { ...model.headers, ...optionsHeaders };
|
||||
}
|
||||
|
||||
return new GoogleGenAI({
|
||||
apiKey,
|
||||
httpOptions: Object.keys(httpOptions).length > 0 ? httpOptions : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options: GoogleOptions = {},
|
||||
): GenerateContentParameters {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const generationConfig: GenerateContentConfig = {};
|
||||
if (options.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
const config: GenerateContentConfig = {
|
||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||
...(context.systemPrompt && { systemInstruction: sanitizeSurrogates(context.systemPrompt) }),
|
||||
...(context.tools && context.tools.length > 0 && { tools: convertTools(context.tools) }),
|
||||
};
|
||||
|
||||
if (context.tools && context.tools.length > 0 && options.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
config.toolConfig = undefined;
|
||||
}
|
||||
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
|
||||
if (options.thinking.level !== undefined) {
|
||||
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
|
||||
thinkingConfig.thinkingLevel = options.thinking.level as any;
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
config.thinkingConfig = thinkingConfig;
|
||||
} else if (model.reasoning && options.thinking && !options.thinking.enabled) {
|
||||
config.thinkingConfig = getDisabledThinkingConfig(model);
|
||||
}
|
||||
|
||||
if (options.signal) {
|
||||
if (options.signal.aborted) {
|
||||
throw new Error("Request aborted");
|
||||
}
|
||||
config.abortSignal = options.signal;
|
||||
}
|
||||
|
||||
const params: GenerateContentParameters = {
|
||||
model: model.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
|
||||
|
||||
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function getDisabledThinkingConfig(model: Model<"google-generative-ai">): ThinkingConfig {
|
||||
// Google docs: Gemini 3.1 Pro cannot disable thinking, and Gemini 3 Flash / Flash-Lite
|
||||
// do not support full thinking-off either. For Gemini 3 models, use the lowest supported
|
||||
// thinkingLevel without includeThoughts so hidden thinking remains invisible to pi.
|
||||
if (isGemini3ProModel(model)) {
|
||||
return { thinkingLevel: "LOW" as any };
|
||||
}
|
||||
if (isGemini3FlashModel(model)) {
|
||||
return { thinkingLevel: "MINIMAL" as any };
|
||||
}
|
||||
|
||||
// Gemini 2.x supports disabling via thinkingBudget = 0.
|
||||
return { thinkingBudget: 0 };
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGoogleBudget(
|
||||
model: Model<"google-generative-ai">,
|
||||
effort: ClampedThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): number {
|
||||
if (customBudgets?.[effort] !== undefined) {
|
||||
return customBudgets[effort]!;
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
585
packages/ai/src/providers/mistral.ts
Normal file
585
packages/ai/src/providers/mistral.ts
Normal file
|
|
@ -0,0 +1,585 @@
|
|||
import { Mistral } from "@mistralai/mistralai";
|
||||
import type { RequestOptions } from "@mistralai/mistralai/lib/sdks.js";
|
||||
import type {
|
||||
ChatCompletionStreamRequest,
|
||||
ChatCompletionStreamRequestMessages,
|
||||
CompletionEvent,
|
||||
ContentChunk,
|
||||
FunctionTool,
|
||||
} from "@mistralai/mistralai/models/components/index.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { shortHash } from "../utils/hash.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
const MISTRAL_TOOL_CALL_ID_LENGTH = 9;
|
||||
const MAX_MISTRAL_ERROR_BODY_CHARS = 4000;
|
||||
|
||||
/**
|
||||
* Provider-specific options for the Mistral API.
|
||||
*/
|
||||
export interface MistralOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } };
|
||||
promptMode?: "reasoning";
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream responses from Mistral using `chat.stream`.
|
||||
*/
|
||||
export const streamMistral: StreamFunction<"mistral-conversations", MistralOptions> = (
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
options?: MistralOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output = createOutput(model);
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
// Intentionally per-request: avoids shared SDK mutable state across concurrent consumers.
|
||||
const mistral = new Mistral({
|
||||
apiKey,
|
||||
serverURL: model.baseUrl,
|
||||
});
|
||||
|
||||
const normalizeMistralToolCallId = createMistralToolCallIdNormalizer();
|
||||
const transformedMessages = transformMessages(context.messages, model, (id) => normalizeMistralToolCallId(id));
|
||||
|
||||
let payload = buildChatPayload(model, context, transformedMessages, options);
|
||||
const nextPayload = await options?.onPayload?.(payload, model);
|
||||
if (nextPayload !== undefined) {
|
||||
payload = nextPayload as ChatCompletionStreamRequest;
|
||||
}
|
||||
const mistralStream = await mistral.chat.stream(payload, buildRequestOptions(model, options));
|
||||
stream.push({ type: "start", partial: output });
|
||||
await consumeChatStream(model, output, stream, mistralStream);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = formatMistralError(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
|
||||
* Maps provider-agnostic `SimpleStreamOptions` to Mistral options.
|
||||
*/
|
||||
export const streamSimpleMistral: StreamFunction<"mistral-conversations", SimpleStreamOptions> = (
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoning = clampReasoning(options?.reasoning);
|
||||
|
||||
return streamMistral(model, context, {
|
||||
...base,
|
||||
promptMode: model.reasoning && reasoning ? "reasoning" : undefined,
|
||||
} satisfies MistralOptions);
|
||||
};
|
||||
|
||||
function createOutput(model: Model<"mistral-conversations">): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function createMistralToolCallIdNormalizer(): (id: string) => string {
|
||||
const idMap = new Map<string, string>();
|
||||
const reverseMap = new Map<string, string>();
|
||||
|
||||
return (id: string): string => {
|
||||
const existing = idMap.get(id);
|
||||
if (existing) return existing;
|
||||
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
const candidate = deriveMistralToolCallId(id, attempt);
|
||||
const owner = reverseMap.get(candidate);
|
||||
if (!owner || owner === id) {
|
||||
idMap.set(id, candidate);
|
||||
reverseMap.set(candidate, id);
|
||||
return candidate;
|
||||
}
|
||||
attempt++;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function deriveMistralToolCallId(id: string, attempt: number): string {
|
||||
const normalized = id.replace(/[^a-zA-Z0-9]/g, "");
|
||||
if (attempt === 0 && normalized.length === MISTRAL_TOOL_CALL_ID_LENGTH) return normalized;
|
||||
const seedBase = normalized || id;
|
||||
const seed = attempt === 0 ? seedBase : `${seedBase}:${attempt}`;
|
||||
return shortHash(seed)
|
||||
.replace(/[^a-zA-Z0-9]/g, "")
|
||||
.slice(0, MISTRAL_TOOL_CALL_ID_LENGTH);
|
||||
}
|
||||
|
||||
function formatMistralError(error: unknown): string {
|
||||
if (error instanceof Error) {
|
||||
const sdkError = error as Error & { statusCode?: unknown; body?: unknown };
|
||||
const statusCode = typeof sdkError.statusCode === "number" ? sdkError.statusCode : undefined;
|
||||
const bodyText = typeof sdkError.body === "string" ? sdkError.body.trim() : undefined;
|
||||
if (statusCode !== undefined && bodyText) {
|
||||
return `Mistral API error (${statusCode}): ${truncateErrorText(bodyText, MAX_MISTRAL_ERROR_BODY_CHARS)}`;
|
||||
}
|
||||
if (statusCode !== undefined) return `Mistral API error (${statusCode}): ${error.message}`;
|
||||
return error.message;
|
||||
}
|
||||
return safeJsonStringify(error);
|
||||
}
|
||||
|
||||
function truncateErrorText(text: string, maxChars: number): string {
|
||||
if (text.length <= maxChars) return text;
|
||||
return `${text.slice(0, maxChars)}... [truncated ${text.length - maxChars} chars]`;
|
||||
}
|
||||
|
||||
function safeJsonStringify(value: unknown): string {
|
||||
try {
|
||||
const serialized = JSON.stringify(value);
|
||||
return serialized === undefined ? String(value) : serialized;
|
||||
} catch {
|
||||
return String(value);
|
||||
}
|
||||
}
|
||||
|
||||
function buildRequestOptions(model: Model<"mistral-conversations">, options?: MistralOptions): RequestOptions {
|
||||
const requestOptions: RequestOptions = {};
|
||||
if (options?.signal) requestOptions.signal = options.signal;
|
||||
requestOptions.retries = { strategy: "none" };
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (model.headers) Object.assign(headers, model.headers);
|
||||
if (options?.headers) Object.assign(headers, options.headers);
|
||||
|
||||
// Mistral infrastructure uses `x-affinity` for KV-cache reuse (prefix caching).
|
||||
// Respect explicit caller-provided header values.
|
||||
if (options?.sessionId && !headers["x-affinity"]) {
|
||||
headers["x-affinity"] = options.sessionId;
|
||||
}
|
||||
|
||||
if (Object.keys(headers).length > 0) {
|
||||
requestOptions.headers = headers;
|
||||
}
|
||||
|
||||
return requestOptions;
|
||||
}
|
||||
|
||||
function buildChatPayload(
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
messages: Message[],
|
||||
options?: MistralOptions,
|
||||
): ChatCompletionStreamRequest {
|
||||
const payload: ChatCompletionStreamRequest = {
|
||||
model: model.id,
|
||||
stream: true,
|
||||
messages: toChatMessages(messages, model.input.includes("image")),
|
||||
};
|
||||
|
||||
if (context.tools?.length) payload.tools = toFunctionTools(context.tools);
|
||||
if (options?.temperature !== undefined) payload.temperature = options.temperature;
|
||||
if (options?.maxTokens !== undefined) payload.maxTokens = options.maxTokens;
|
||||
if (options?.toolChoice) payload.toolChoice = mapToolChoice(options.toolChoice);
|
||||
if (options?.promptMode) payload.promptMode = options.promptMode as any;
|
||||
|
||||
if (context.systemPrompt) {
|
||||
payload.messages.unshift({
|
||||
role: "system",
|
||||
content: sanitizeSurrogates(context.systemPrompt),
|
||||
});
|
||||
}
|
||||
|
||||
return payload;
|
||||
}
|
||||
|
||||
async function consumeChatStream(
|
||||
model: Model<"mistral-conversations">,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
mistralStream: AsyncIterable<CompletionEvent>,
|
||||
): Promise<void> {
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
const toolBlocksByKey = new Map<string, number>();
|
||||
|
||||
const finishCurrentBlock = (block?: typeof currentBlock) => {
|
||||
if (!block) return;
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
for await (const event of mistralStream) {
|
||||
const chunk = event.data;
|
||||
// Mistral's streamed CompletionChunk carries an id field. Keep the first non-empty one,
|
||||
// mirroring how OpenAI-style streaming exposes a stable response identifier per stream.
|
||||
output.responseId ||= chunk.id;
|
||||
|
||||
if (chunk.usage) {
|
||||
output.usage.input = chunk.usage.promptTokens || 0;
|
||||
output.usage.output = chunk.usage.completionTokens || 0;
|
||||
output.usage.cacheRead = 0;
|
||||
output.usage.cacheWrite = 0;
|
||||
output.usage.totalTokens = chunk.usage.totalTokens || output.usage.input + output.usage.output;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
|
||||
const choice = chunk.choices[0];
|
||||
if (!choice) continue;
|
||||
|
||||
if (choice.finishReason) {
|
||||
output.stopReason = mapChatStopReason(choice.finishReason);
|
||||
}
|
||||
|
||||
const delta = choice.delta;
|
||||
if (delta.content !== null && delta.content !== undefined) {
|
||||
const contentItems = typeof delta.content === "string" ? [delta.content] : delta.content;
|
||||
for (const item of contentItems) {
|
||||
if (typeof item === "string") {
|
||||
const textDelta = sanitizeSurrogates(item);
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
currentBlock.text += textDelta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: textDelta,
|
||||
partial: output,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.type === "thinking") {
|
||||
const deltaText = item.thinking
|
||||
.map((part) => ("text" in part ? part.text : ""))
|
||||
.filter((text) => text.length > 0)
|
||||
.join("");
|
||||
const thinkingDelta = sanitizeSurrogates(deltaText);
|
||||
if (!thinkingDelta) continue;
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
currentBlock.thinking += thinkingDelta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: thinkingDelta,
|
||||
partial: output,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.type === "text") {
|
||||
const textDelta = sanitizeSurrogates(item.text);
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
currentBlock.text += textDelta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: textDelta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const toolCalls = delta.toolCalls || [];
|
||||
for (const toolCall of toolCalls) {
|
||||
if (currentBlock) {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = null;
|
||||
}
|
||||
const callId =
|
||||
toolCall.id && toolCall.id !== "null"
|
||||
? toolCall.id
|
||||
: deriveMistralToolCallId(`toolcall:${toolCall.index ?? 0}`, 0);
|
||||
const key = `${callId}:${toolCall.index || 0}`;
|
||||
const existingIndex = toolBlocksByKey.get(key);
|
||||
let block: (ToolCall & { partialArgs?: string }) | undefined;
|
||||
|
||||
if (existingIndex !== undefined) {
|
||||
const existing = output.content[existingIndex];
|
||||
if (existing?.type === "toolCall") {
|
||||
block = existing as ToolCall & { partialArgs?: string };
|
||||
}
|
||||
}
|
||||
|
||||
if (!block) {
|
||||
block = {
|
||||
type: "toolCall",
|
||||
id: callId,
|
||||
name: toolCall.function.name,
|
||||
arguments: {},
|
||||
partialArgs: "",
|
||||
};
|
||||
output.content.push(block);
|
||||
toolBlocksByKey.set(key, output.content.length - 1);
|
||||
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
|
||||
}
|
||||
|
||||
const argsDelta =
|
||||
typeof toolCall.function.arguments === "string"
|
||||
? toolCall.function.arguments
|
||||
: JSON.stringify(toolCall.function.arguments || {});
|
||||
block.partialArgs = (block.partialArgs || "") + argsDelta;
|
||||
block.arguments = parseStreamingJson<Record<string, unknown>>(block.partialArgs);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: toolBlocksByKey.get(key)!,
|
||||
delta: argsDelta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
finishCurrentBlock(currentBlock);
|
||||
for (const index of toolBlocksByKey.values()) {
|
||||
const block = output.content[index];
|
||||
if (block.type !== "toolCall") continue;
|
||||
const toolBlock = block as ToolCall & { partialArgs?: string };
|
||||
toolBlock.arguments = parseStreamingJson<Record<string, unknown>>(toolBlock.partialArgs);
|
||||
delete toolBlock.partialArgs;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: index,
|
||||
toolCall: toolBlock,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function toFunctionTools(tools: Tool[]): Array<FunctionTool & { type: "function" }> {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as unknown as Record<string, unknown>,
|
||||
strict: false,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function toChatMessages(messages: Message[], supportsImages: boolean): ChatCompletionStreamRequestMessages[] {
|
||||
const result: ChatCompletionStreamRequestMessages[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
result.push({ role: "user", content: sanitizeSurrogates(msg.content) });
|
||||
continue;
|
||||
}
|
||||
const hadImages = msg.content.some((item) => item.type === "image");
|
||||
const content: ContentChunk[] = msg.content
|
||||
.filter((item) => item.type === "text" || supportsImages)
|
||||
.map((item) => {
|
||||
if (item.type === "text") return { type: "text", text: sanitizeSurrogates(item.text) };
|
||||
return { type: "image_url", imageUrl: `data:${item.mimeType};base64,${item.data}` };
|
||||
});
|
||||
if (content.length > 0) {
|
||||
result.push({ role: "user", content });
|
||||
continue;
|
||||
}
|
||||
if (hadImages && !supportsImages) {
|
||||
result.push({ role: "user", content: "(image omitted: model does not support images)" });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (msg.role === "assistant") {
|
||||
const contentParts: ContentChunk[] = [];
|
||||
const toolCalls: Array<{ id: string; type: "function"; function: { name: string; arguments: string } }> = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
if (block.text.trim().length > 0) {
|
||||
contentParts.push({ type: "text", text: sanitizeSurrogates(block.text) });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (block.type === "thinking") {
|
||||
if (block.thinking.trim().length > 0) {
|
||||
contentParts.push({
|
||||
type: "thinking",
|
||||
thinking: [{ type: "text", text: sanitizeSurrogates(block.thinking) }],
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
toolCalls.push({
|
||||
id: block.id,
|
||||
type: "function",
|
||||
function: { name: block.name, arguments: JSON.stringify(block.arguments || {}) },
|
||||
});
|
||||
}
|
||||
|
||||
const assistantMessage: ChatCompletionStreamRequestMessages = { role: "assistant" };
|
||||
if (contentParts.length > 0) assistantMessage.content = contentParts;
|
||||
if (toolCalls.length > 0) assistantMessage.toolCalls = toolCalls;
|
||||
if (contentParts.length > 0 || toolCalls.length > 0) result.push(assistantMessage);
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolContent: ContentChunk[] = [];
|
||||
const textResult = msg.content
|
||||
.filter((part) => part.type === "text")
|
||||
.map((part) => (part.type === "text" ? sanitizeSurrogates(part.text) : ""))
|
||||
.join("\n");
|
||||
const hasImages = msg.content.some((part) => part.type === "image");
|
||||
const toolText = buildToolResultText(textResult, hasImages, supportsImages, msg.isError);
|
||||
toolContent.push({ type: "text", text: toolText });
|
||||
for (const part of msg.content) {
|
||||
if (!supportsImages) continue;
|
||||
if (part.type !== "image") continue;
|
||||
toolContent.push({
|
||||
type: "image_url",
|
||||
imageUrl: `data:${part.mimeType};base64,${part.data}`,
|
||||
});
|
||||
}
|
||||
result.push({
|
||||
role: "tool",
|
||||
toolCallId: msg.toolCallId,
|
||||
name: msg.toolName,
|
||||
content: toolContent,
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function buildToolResultText(text: string, hasImages: boolean, supportsImages: boolean, isError: boolean): string {
|
||||
const trimmed = text.trim();
|
||||
const errorPrefix = isError ? "[tool error] " : "";
|
||||
|
||||
if (trimmed.length > 0) {
|
||||
const imageSuffix = hasImages && !supportsImages ? "\n[tool image omitted: model does not support images]" : "";
|
||||
return `${errorPrefix}${trimmed}${imageSuffix}`;
|
||||
}
|
||||
|
||||
if (hasImages) {
|
||||
if (supportsImages) {
|
||||
return isError ? "[tool error] (see attached image)" : "(see attached image)";
|
||||
}
|
||||
return isError
|
||||
? "[tool error] (image omitted: model does not support images)"
|
||||
: "(image omitted: model does not support images)";
|
||||
}
|
||||
|
||||
return isError ? "[tool error] (no tool output)" : "(no tool output)";
|
||||
}
|
||||
|
||||
function mapToolChoice(
|
||||
choice: MistralOptions["toolChoice"],
|
||||
): "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } } | undefined {
|
||||
if (!choice) return undefined;
|
||||
if (choice === "auto" || choice === "none" || choice === "any" || choice === "required") {
|
||||
return choice as any;
|
||||
}
|
||||
return {
|
||||
type: "function",
|
||||
function: { name: choice.function.name },
|
||||
};
|
||||
}
|
||||
|
||||
function mapChatStopReason(reason: string | null): StopReason {
|
||||
if (reason === null) return "stop";
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
case "model_length":
|
||||
return "length";
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "error":
|
||||
return "error";
|
||||
default:
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
929
packages/ai/src/providers/openai-codex-responses.ts
Normal file
929
packages/ai/src/providers/openai-codex-responses.ts
Normal file
|
|
@ -0,0 +1,929 @@
|
|||
import type * as NodeOs from "node:os";
|
||||
import type { Tool as OpenAITool, ResponseInput, ResponseStreamEvent } from "openai/resources/responses/responses.js";
|
||||
|
||||
// NEVER convert to top-level runtime imports - breaks browser/Vite builds (web-ui)
|
||||
let _os: typeof NodeOs | null = null;
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;
|
||||
|
||||
const dynamicImport: DynamicImport = (specifier) => import(specifier);
|
||||
const NODE_OS_SPECIFIER = "node:" + "os";
|
||||
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
dynamicImport(NODE_OS_SPECIFIER).then((m) => {
|
||||
_os = m as typeof NodeOs;
|
||||
});
|
||||
}
|
||||
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
// ============================================================================
|
||||
// Configuration
|
||||
// ============================================================================
|
||||
|
||||
const DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api";
|
||||
const JWT_CLAIM_PATH = "https://api.openai.com/auth" as const;
|
||||
const MAX_RETRIES = 3;
|
||||
const BASE_DELAY_MS = 1000;
|
||||
const CODEX_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
|
||||
|
||||
const CODEX_RESPONSE_STATUSES = new Set<CodexResponseStatus>([
|
||||
"completed",
|
||||
"incomplete",
|
||||
"failed",
|
||||
"cancelled",
|
||||
"queued",
|
||||
"in_progress",
|
||||
]);
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface OpenAICodexResponsesOptions extends StreamOptions {
|
||||
reasoningEffort?: "none" | "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
reasoningSummary?: "auto" | "concise" | "detailed" | "off" | "on" | null;
|
||||
textVerbosity?: "low" | "medium" | "high";
|
||||
}
|
||||
|
||||
type CodexResponseStatus = "completed" | "incomplete" | "failed" | "cancelled" | "queued" | "in_progress";
|
||||
|
||||
interface RequestBody {
|
||||
model: string;
|
||||
store?: boolean;
|
||||
stream?: boolean;
|
||||
instructions?: string;
|
||||
input?: ResponseInput;
|
||||
tools?: OpenAITool[];
|
||||
tool_choice?: "auto";
|
||||
parallel_tool_calls?: boolean;
|
||||
temperature?: number;
|
||||
reasoning?: { effort?: string; summary?: string };
|
||||
text?: { verbosity?: string };
|
||||
include?: string[];
|
||||
prompt_cache_key?: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Retry Helpers
|
||||
// ============================================================================
|
||||
|
||||
function isRetryableError(status: number, errorText: string): boolean {
|
||||
if (status === 429 || status === 500 || status === 502 || status === 503 || status === 504) {
|
||||
return true;
|
||||
}
|
||||
return /rate.?limit|overloaded|service.?unavailable|upstream.?connect|connection.?refused/i.test(errorText);
|
||||
}
|
||||
|
||||
function sleep(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Request was aborted"));
|
||||
return;
|
||||
}
|
||||
const timeout = setTimeout(resolve, ms);
|
||||
signal?.addEventListener("abort", () => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Request was aborted"));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Stream Function
|
||||
// ============================================================================
|
||||
|
||||
export const streamOpenAICodexResponses: StreamFunction<"openai-codex-responses", OpenAICodexResponsesOptions> = (
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "openai-codex-responses" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const accountId = extractAccountId(apiKey);
|
||||
let body = buildRequestBody(model, context, options);
|
||||
const nextBody = await options?.onPayload?.(body, model);
|
||||
if (nextBody !== undefined) {
|
||||
body = nextBody as RequestBody;
|
||||
}
|
||||
const websocketRequestId = options?.sessionId || createCodexRequestId();
|
||||
const sseHeaders = buildSSEHeaders(model.headers, options?.headers, accountId, apiKey, options?.sessionId);
|
||||
const websocketHeaders = buildWebSocketHeaders(
|
||||
model.headers,
|
||||
options?.headers,
|
||||
accountId,
|
||||
apiKey,
|
||||
websocketRequestId,
|
||||
);
|
||||
const bodyJson = JSON.stringify(body);
|
||||
const transport = options?.transport || "sse";
|
||||
|
||||
if (transport !== "sse") {
|
||||
let websocketStarted = false;
|
||||
try {
|
||||
await processWebSocketStream(
|
||||
resolveCodexWebSocketUrl(model.baseUrl),
|
||||
body,
|
||||
websocketHeaders,
|
||||
output,
|
||||
stream,
|
||||
model,
|
||||
() => {
|
||||
websocketStarted = true;
|
||||
},
|
||||
options,
|
||||
);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
stream.push({
|
||||
type: "done",
|
||||
reason: output.stopReason as "stop" | "length" | "toolUse",
|
||||
message: output,
|
||||
});
|
||||
stream.end();
|
||||
return;
|
||||
} catch (error) {
|
||||
if (transport === "websocket" || websocketStarted) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch with retry logic for rate limits and transient errors
|
||||
let response: Response | undefined;
|
||||
let lastError: Error | undefined;
|
||||
|
||||
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
try {
|
||||
response = await fetch(resolveCodexUrl(model.baseUrl), {
|
||||
method: "POST",
|
||||
headers: sseHeaders,
|
||||
body: bodyJson,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
break;
|
||||
}
|
||||
|
||||
const errorText = await response.text();
|
||||
if (attempt < MAX_RETRIES && isRetryableError(response.status, errorText)) {
|
||||
const delayMs = BASE_DELAY_MS * 2 ** attempt;
|
||||
await sleep(delayMs, options?.signal);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse error for friendly message on final attempt or non-retryable error
|
||||
const fakeResponse = new Response(errorText, {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
const info = await parseErrorResponse(fakeResponse);
|
||||
throw new Error(info.friendlyMessage || info.message);
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
if (error.name === "AbortError" || error.message === "Request was aborted") {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
}
|
||||
lastError = error instanceof Error ? error : new Error(String(error));
|
||||
// Network errors are retryable
|
||||
if (attempt < MAX_RETRIES && !lastError.message.includes("usage limit")) {
|
||||
const delayMs = BASE_DELAY_MS * 2 ** attempt;
|
||||
await sleep(delayMs, options?.signal);
|
||||
continue;
|
||||
}
|
||||
throw lastError;
|
||||
}
|
||||
}
|
||||
|
||||
if (!response?.ok) {
|
||||
throw lastError ?? new Error("Failed after retries");
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
stream.push({ type: "start", partial: output });
|
||||
await processStream(response, output, stream, model);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason as "stop" | "length" | "toolUse", message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : String(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleOpenAICodexResponses: StreamFunction<"openai-codex-responses", SimpleStreamOptions> = (
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
|
||||
|
||||
return streamOpenAICodexResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies OpenAICodexResponsesOptions);
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Request Building
|
||||
// ============================================================================
|
||||
|
||||
function buildRequestBody(
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): RequestBody {
|
||||
const messages = convertResponsesMessages(model, context, CODEX_TOOL_CALL_PROVIDERS, {
|
||||
includeSystemPrompt: false,
|
||||
});
|
||||
|
||||
const body: RequestBody = {
|
||||
model: model.id,
|
||||
store: false,
|
||||
stream: true,
|
||||
instructions: context.systemPrompt,
|
||||
input: messages,
|
||||
text: { verbosity: options?.textVerbosity || "medium" },
|
||||
include: ["reasoning.encrypted_content"],
|
||||
prompt_cache_key: options?.sessionId,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: true,
|
||||
};
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
body.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
body.tools = convertResponsesTools(context.tools, { strict: null });
|
||||
}
|
||||
|
||||
if (options?.reasoningEffort !== undefined) {
|
||||
body.reasoning = {
|
||||
effort: clampReasoningEffort(model.id, options.reasoningEffort),
|
||||
summary: options.reasoningSummary ?? "auto",
|
||||
};
|
||||
}
|
||||
|
||||
return body;
|
||||
}
|
||||
|
||||
function clampReasoningEffort(modelId: string, effort: string): string {
|
||||
const id = modelId.includes("/") ? modelId.split("/").pop()! : modelId;
|
||||
if ((id.startsWith("gpt-5.2") || id.startsWith("gpt-5.3") || id.startsWith("gpt-5.4")) && effort === "minimal")
|
||||
return "low";
|
||||
if (id === "gpt-5.1" && effort === "xhigh") return "high";
|
||||
if (id === "gpt-5.1-codex-mini") return effort === "high" || effort === "xhigh" ? "high" : "medium";
|
||||
return effort;
|
||||
}
|
||||
|
||||
function resolveCodexUrl(baseUrl?: string): string {
|
||||
const raw = baseUrl && baseUrl.trim().length > 0 ? baseUrl : DEFAULT_CODEX_BASE_URL;
|
||||
const normalized = raw.replace(/\/+$/, "");
|
||||
if (normalized.endsWith("/codex/responses")) return normalized;
|
||||
if (normalized.endsWith("/codex")) return `${normalized}/responses`;
|
||||
return `${normalized}/codex/responses`;
|
||||
}
|
||||
|
||||
function resolveCodexWebSocketUrl(baseUrl?: string): string {
|
||||
const url = new URL(resolveCodexUrl(baseUrl));
|
||||
if (url.protocol === "https:") url.protocol = "wss:";
|
||||
if (url.protocol === "http:") url.protocol = "ws:";
|
||||
return url.toString();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Processing
|
||||
// ============================================================================
|
||||
|
||||
async function processStream(
|
||||
response: Response,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
model: Model<"openai-codex-responses">,
|
||||
): Promise<void> {
|
||||
await processResponsesStream(mapCodexEvents(parseSSE(response)), output, stream, model);
|
||||
}
|
||||
|
||||
async function* mapCodexEvents(events: AsyncIterable<Record<string, unknown>>): AsyncGenerator<ResponseStreamEvent> {
|
||||
for await (const event of events) {
|
||||
const type = typeof event.type === "string" ? event.type : undefined;
|
||||
if (!type) continue;
|
||||
|
||||
if (type === "error") {
|
||||
const code = (event as { code?: string }).code || "";
|
||||
const message = (event as { message?: string }).message || "";
|
||||
throw new Error(`Codex error: ${message || code || JSON.stringify(event)}`);
|
||||
}
|
||||
|
||||
if (type === "response.failed") {
|
||||
const msg = (event as { response?: { error?: { message?: string } } }).response?.error?.message;
|
||||
throw new Error(msg || "Codex response failed");
|
||||
}
|
||||
|
||||
if (type === "response.done" || type === "response.completed" || type === "response.incomplete") {
|
||||
const response = (event as { response?: { status?: unknown } }).response;
|
||||
const normalizedResponse = response
|
||||
? { ...response, status: normalizeCodexStatus(response.status) }
|
||||
: response;
|
||||
yield { ...event, type: "response.completed", response: normalizedResponse } as ResponseStreamEvent;
|
||||
return;
|
||||
}
|
||||
|
||||
yield event as unknown as ResponseStreamEvent;
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeCodexStatus(status: unknown): CodexResponseStatus | undefined {
|
||||
if (typeof status !== "string") return undefined;
|
||||
return CODEX_RESPONSE_STATUSES.has(status as CodexResponseStatus) ? (status as CodexResponseStatus) : undefined;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE Parsing
|
||||
// ============================================================================
|
||||
|
||||
async function* parseSSE(response: Response): AsyncGenerator<Record<string, unknown>> {
|
||||
if (!response.body) return;
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
let idx = buffer.indexOf("\n\n");
|
||||
while (idx !== -1) {
|
||||
const chunk = buffer.slice(0, idx);
|
||||
buffer = buffer.slice(idx + 2);
|
||||
|
||||
const dataLines = chunk
|
||||
.split("\n")
|
||||
.filter((l) => l.startsWith("data:"))
|
||||
.map((l) => l.slice(5).trim());
|
||||
if (dataLines.length > 0) {
|
||||
const data = dataLines.join("\n").trim();
|
||||
if (data && data !== "[DONE]") {
|
||||
try {
|
||||
yield JSON.parse(data);
|
||||
} catch {}
|
||||
}
|
||||
}
|
||||
idx = buffer.indexOf("\n\n");
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
try {
|
||||
await reader.cancel();
|
||||
} catch {}
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket Parsing
|
||||
// ============================================================================
|
||||
|
||||
const OPENAI_BETA_RESPONSES_WEBSOCKETS = "responses_websockets=2026-02-06";
|
||||
const SESSION_WEBSOCKET_CACHE_TTL_MS = 5 * 60 * 1000;
|
||||
|
||||
type WebSocketEventType = "open" | "message" | "error" | "close";
|
||||
type WebSocketListener = (event: unknown) => void;
|
||||
|
||||
interface WebSocketLike {
|
||||
close(code?: number, reason?: string): void;
|
||||
send(data: string): void;
|
||||
addEventListener(type: WebSocketEventType, listener: WebSocketListener): void;
|
||||
removeEventListener(type: WebSocketEventType, listener: WebSocketListener): void;
|
||||
}
|
||||
|
||||
interface CachedWebSocketConnection {
|
||||
socket: WebSocketLike;
|
||||
busy: boolean;
|
||||
idleTimer?: ReturnType<typeof setTimeout>;
|
||||
}
|
||||
|
||||
const websocketSessionCache = new Map<string, CachedWebSocketConnection>();
|
||||
|
||||
type WebSocketConstructor = new (
|
||||
url: string,
|
||||
protocols?: string | string[] | { headers?: Record<string, string> },
|
||||
) => WebSocketLike;
|
||||
|
||||
function getWebSocketConstructor(): WebSocketConstructor | null {
|
||||
const ctor = (globalThis as { WebSocket?: unknown }).WebSocket;
|
||||
if (typeof ctor !== "function") return null;
|
||||
return ctor as unknown as WebSocketConstructor;
|
||||
}
|
||||
|
||||
function headersToRecord(headers: Headers): Record<string, string> {
|
||||
const out: Record<string, string> = {};
|
||||
for (const [key, value] of headers.entries()) {
|
||||
out[key] = value;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function getWebSocketReadyState(socket: WebSocketLike): number | undefined {
|
||||
const readyState = (socket as { readyState?: unknown }).readyState;
|
||||
return typeof readyState === "number" ? readyState : undefined;
|
||||
}
|
||||
|
||||
function isWebSocketReusable(socket: WebSocketLike): boolean {
|
||||
const readyState = getWebSocketReadyState(socket);
|
||||
// If readyState is unavailable, assume the runtime keeps it open/reusable.
|
||||
return readyState === undefined || readyState === 1;
|
||||
}
|
||||
|
||||
function closeWebSocketSilently(socket: WebSocketLike, code = 1000, reason = "done"): void {
|
||||
try {
|
||||
socket.close(code, reason);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
function scheduleSessionWebSocketExpiry(sessionId: string, entry: CachedWebSocketConnection): void {
|
||||
if (entry.idleTimer) {
|
||||
clearTimeout(entry.idleTimer);
|
||||
}
|
||||
entry.idleTimer = setTimeout(() => {
|
||||
if (entry.busy) return;
|
||||
closeWebSocketSilently(entry.socket, 1000, "idle_timeout");
|
||||
websocketSessionCache.delete(sessionId);
|
||||
}, SESSION_WEBSOCKET_CACHE_TTL_MS);
|
||||
}
|
||||
|
||||
async function connectWebSocket(url: string, headers: Headers, signal?: AbortSignal): Promise<WebSocketLike> {
|
||||
const WebSocketCtor = getWebSocketConstructor();
|
||||
if (!WebSocketCtor) {
|
||||
throw new Error("WebSocket transport is not available in this runtime");
|
||||
}
|
||||
|
||||
const wsHeaders = headersToRecord(headers);
|
||||
delete wsHeaders["OpenAI-Beta"];
|
||||
|
||||
return new Promise<WebSocketLike>((resolve, reject) => {
|
||||
let settled = false;
|
||||
let socket: WebSocketLike;
|
||||
|
||||
try {
|
||||
socket = new WebSocketCtor(url, { headers: wsHeaders });
|
||||
} catch (error) {
|
||||
reject(error instanceof Error ? error : new Error(String(error)));
|
||||
return;
|
||||
}
|
||||
|
||||
const onOpen: WebSocketListener = () => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
resolve(socket);
|
||||
};
|
||||
const onError: WebSocketListener = (event) => {
|
||||
const error = extractWebSocketError(event);
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
reject(error);
|
||||
};
|
||||
const onClose: WebSocketListener = (event) => {
|
||||
const error = extractWebSocketCloseError(event);
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
reject(error);
|
||||
};
|
||||
const onAbort = () => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
socket.close(1000, "aborted");
|
||||
reject(new Error("Request was aborted"));
|
||||
};
|
||||
|
||||
const cleanup = () => {
|
||||
socket.removeEventListener("open", onOpen);
|
||||
socket.removeEventListener("error", onError);
|
||||
socket.removeEventListener("close", onClose);
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
};
|
||||
|
||||
socket.addEventListener("open", onOpen);
|
||||
socket.addEventListener("error", onError);
|
||||
socket.addEventListener("close", onClose);
|
||||
signal?.addEventListener("abort", onAbort);
|
||||
});
|
||||
}
|
||||
|
||||
async function acquireWebSocket(
|
||||
url: string,
|
||||
headers: Headers,
|
||||
sessionId: string | undefined,
|
||||
signal?: AbortSignal,
|
||||
): Promise<{ socket: WebSocketLike; release: (options?: { keep?: boolean }) => void }> {
|
||||
if (!sessionId) {
|
||||
const socket = await connectWebSocket(url, headers, signal);
|
||||
return {
|
||||
socket,
|
||||
release: ({ keep } = {}) => {
|
||||
if (keep === false) {
|
||||
closeWebSocketSilently(socket);
|
||||
return;
|
||||
}
|
||||
closeWebSocketSilently(socket);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const cached = websocketSessionCache.get(sessionId);
|
||||
if (cached) {
|
||||
if (cached.idleTimer) {
|
||||
clearTimeout(cached.idleTimer);
|
||||
cached.idleTimer = undefined;
|
||||
}
|
||||
if (!cached.busy && isWebSocketReusable(cached.socket)) {
|
||||
cached.busy = true;
|
||||
return {
|
||||
socket: cached.socket,
|
||||
release: ({ keep } = {}) => {
|
||||
if (!keep || !isWebSocketReusable(cached.socket)) {
|
||||
closeWebSocketSilently(cached.socket);
|
||||
websocketSessionCache.delete(sessionId);
|
||||
return;
|
||||
}
|
||||
cached.busy = false;
|
||||
scheduleSessionWebSocketExpiry(sessionId, cached);
|
||||
},
|
||||
};
|
||||
}
|
||||
if (cached.busy) {
|
||||
const socket = await connectWebSocket(url, headers, signal);
|
||||
return {
|
||||
socket,
|
||||
release: () => {
|
||||
closeWebSocketSilently(socket);
|
||||
},
|
||||
};
|
||||
}
|
||||
if (!isWebSocketReusable(cached.socket)) {
|
||||
closeWebSocketSilently(cached.socket);
|
||||
websocketSessionCache.delete(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
const socket = await connectWebSocket(url, headers, signal);
|
||||
const entry: CachedWebSocketConnection = { socket, busy: true };
|
||||
websocketSessionCache.set(sessionId, entry);
|
||||
return {
|
||||
socket,
|
||||
release: ({ keep } = {}) => {
|
||||
if (!keep || !isWebSocketReusable(entry.socket)) {
|
||||
closeWebSocketSilently(entry.socket);
|
||||
if (entry.idleTimer) clearTimeout(entry.idleTimer);
|
||||
if (websocketSessionCache.get(sessionId) === entry) {
|
||||
websocketSessionCache.delete(sessionId);
|
||||
}
|
||||
return;
|
||||
}
|
||||
entry.busy = false;
|
||||
scheduleSessionWebSocketExpiry(sessionId, entry);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function extractWebSocketError(event: unknown): Error {
|
||||
if (event && typeof event === "object" && "message" in event) {
|
||||
const message = (event as { message?: unknown }).message;
|
||||
if (typeof message === "string" && message.length > 0) {
|
||||
return new Error(message);
|
||||
}
|
||||
}
|
||||
return new Error("WebSocket error");
|
||||
}
|
||||
|
||||
function extractWebSocketCloseError(event: unknown): Error {
|
||||
if (event && typeof event === "object") {
|
||||
const code = "code" in event ? (event as { code?: unknown }).code : undefined;
|
||||
const reason = "reason" in event ? (event as { reason?: unknown }).reason : undefined;
|
||||
const codeText = typeof code === "number" ? ` ${code}` : "";
|
||||
const reasonText = typeof reason === "string" && reason.length > 0 ? ` ${reason}` : "";
|
||||
return new Error(`WebSocket closed${codeText}${reasonText}`.trim());
|
||||
}
|
||||
return new Error("WebSocket closed");
|
||||
}
|
||||
|
||||
async function decodeWebSocketData(data: unknown): Promise<string | null> {
|
||||
if (typeof data === "string") return data;
|
||||
if (data instanceof ArrayBuffer) {
|
||||
return new TextDecoder().decode(new Uint8Array(data));
|
||||
}
|
||||
if (ArrayBuffer.isView(data)) {
|
||||
const view = data as ArrayBufferView;
|
||||
return new TextDecoder().decode(new Uint8Array(view.buffer, view.byteOffset, view.byteLength));
|
||||
}
|
||||
if (data && typeof data === "object" && "arrayBuffer" in data) {
|
||||
const blobLike = data as { arrayBuffer: () => Promise<ArrayBuffer> };
|
||||
const arrayBuffer = await blobLike.arrayBuffer();
|
||||
return new TextDecoder().decode(new Uint8Array(arrayBuffer));
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async function* parseWebSocket(socket: WebSocketLike, signal?: AbortSignal): AsyncGenerator<Record<string, unknown>> {
|
||||
const queue: Record<string, unknown>[] = [];
|
||||
let pending: (() => void) | null = null;
|
||||
let done = false;
|
||||
let failed: Error | null = null;
|
||||
let sawCompletion = false;
|
||||
|
||||
const wake = () => {
|
||||
if (!pending) return;
|
||||
const resolve = pending;
|
||||
pending = null;
|
||||
resolve();
|
||||
};
|
||||
|
||||
const onMessage: WebSocketListener = (event) => {
|
||||
void (async () => {
|
||||
if (!event || typeof event !== "object" || !("data" in event)) return;
|
||||
const text = await decodeWebSocketData((event as { data?: unknown }).data);
|
||||
if (!text) return;
|
||||
try {
|
||||
const parsed = JSON.parse(text) as Record<string, unknown>;
|
||||
const type = typeof parsed.type === "string" ? parsed.type : "";
|
||||
if (type === "response.completed" || type === "response.done" || type === "response.incomplete") {
|
||||
sawCompletion = true;
|
||||
done = true;
|
||||
}
|
||||
queue.push(parsed);
|
||||
wake();
|
||||
} catch {}
|
||||
})();
|
||||
};
|
||||
|
||||
const onError: WebSocketListener = (event) => {
|
||||
failed = extractWebSocketError(event);
|
||||
done = true;
|
||||
wake();
|
||||
};
|
||||
|
||||
const onClose: WebSocketListener = (event) => {
|
||||
if (sawCompletion) {
|
||||
done = true;
|
||||
wake();
|
||||
return;
|
||||
}
|
||||
if (!failed) {
|
||||
failed = extractWebSocketCloseError(event);
|
||||
}
|
||||
done = true;
|
||||
wake();
|
||||
};
|
||||
|
||||
const onAbort = () => {
|
||||
failed = new Error("Request was aborted");
|
||||
done = true;
|
||||
wake();
|
||||
};
|
||||
|
||||
socket.addEventListener("message", onMessage);
|
||||
socket.addEventListener("error", onError);
|
||||
socket.addEventListener("close", onClose);
|
||||
signal?.addEventListener("abort", onAbort);
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
if (signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
if (queue.length > 0) {
|
||||
yield queue.shift()!;
|
||||
continue;
|
||||
}
|
||||
if (done) break;
|
||||
await new Promise<void>((resolve) => {
|
||||
pending = resolve;
|
||||
});
|
||||
}
|
||||
|
||||
if (failed) {
|
||||
throw failed;
|
||||
}
|
||||
if (!sawCompletion) {
|
||||
throw new Error("WebSocket stream closed before response.completed");
|
||||
}
|
||||
} finally {
|
||||
socket.removeEventListener("message", onMessage);
|
||||
socket.removeEventListener("error", onError);
|
||||
socket.removeEventListener("close", onClose);
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
}
|
||||
}
|
||||
|
||||
async function processWebSocketStream(
|
||||
url: string,
|
||||
body: RequestBody,
|
||||
headers: Headers,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
model: Model<"openai-codex-responses">,
|
||||
onStart: () => void,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): Promise<void> {
|
||||
const { socket, release } = await acquireWebSocket(url, headers, options?.sessionId, options?.signal);
|
||||
let keepConnection = true;
|
||||
try {
|
||||
socket.send(JSON.stringify({ type: "response.create", ...body }));
|
||||
onStart();
|
||||
stream.push({ type: "start", partial: output });
|
||||
await processResponsesStream(mapCodexEvents(parseWebSocket(socket, options?.signal)), output, stream, model);
|
||||
if (options?.signal?.aborted) {
|
||||
keepConnection = false;
|
||||
}
|
||||
} catch (error) {
|
||||
keepConnection = false;
|
||||
throw error;
|
||||
} finally {
|
||||
release({ keep: keepConnection });
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handling
|
||||
// ============================================================================
|
||||
|
||||
async function parseErrorResponse(response: Response): Promise<{ message: string; friendlyMessage?: string }> {
|
||||
const raw = await response.text();
|
||||
let message = raw || response.statusText || "Request failed";
|
||||
let friendlyMessage: string | undefined;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(raw) as {
|
||||
error?: { code?: string; type?: string; message?: string; plan_type?: string; resets_at?: number };
|
||||
};
|
||||
const err = parsed?.error;
|
||||
if (err) {
|
||||
const code = err.code || err.type || "";
|
||||
if (/usage_limit_reached|usage_not_included|rate_limit_exceeded/i.test(code) || response.status === 429) {
|
||||
const plan = err.plan_type ? ` (${err.plan_type.toLowerCase()} plan)` : "";
|
||||
const mins = err.resets_at
|
||||
? Math.max(0, Math.round((err.resets_at * 1000 - Date.now()) / 60000))
|
||||
: undefined;
|
||||
const when = mins !== undefined ? ` Try again in ~${mins} min.` : "";
|
||||
friendlyMessage = `You have hit your ChatGPT usage limit${plan}.${when}`.trim();
|
||||
}
|
||||
message = err.message || friendlyMessage || message;
|
||||
}
|
||||
} catch {}
|
||||
|
||||
return { message, friendlyMessage };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Auth & Headers
|
||||
// ============================================================================
|
||||
|
||||
function extractAccountId(token: string): string {
|
||||
try {
|
||||
const parts = token.split(".");
|
||||
if (parts.length !== 3) throw new Error("Invalid token");
|
||||
const payload = JSON.parse(atob(parts[1]));
|
||||
const accountId = payload?.[JWT_CLAIM_PATH]?.chatgpt_account_id;
|
||||
if (!accountId) throw new Error("No account ID in token");
|
||||
return accountId;
|
||||
} catch {
|
||||
throw new Error("Failed to extract accountId from token");
|
||||
}
|
||||
}
|
||||
|
||||
function createCodexRequestId(): string {
|
||||
if (typeof globalThis.crypto?.randomUUID === "function") {
|
||||
return globalThis.crypto.randomUUID();
|
||||
}
|
||||
return `codex_${Date.now()}_${Math.random().toString(36).slice(2, 10)}`;
|
||||
}
|
||||
|
||||
function buildBaseCodexHeaders(
|
||||
initHeaders: Record<string, string> | undefined,
|
||||
additionalHeaders: Record<string, string> | undefined,
|
||||
accountId: string,
|
||||
token: string,
|
||||
): Headers {
|
||||
const headers = new Headers(initHeaders);
|
||||
for (const [key, value] of Object.entries(additionalHeaders || {})) {
|
||||
headers.set(key, value);
|
||||
}
|
||||
headers.set("Authorization", `Bearer ${token}`);
|
||||
headers.set("chatgpt-account-id", accountId);
|
||||
headers.set("originator", "pi");
|
||||
const userAgent = _os ? `pi (${_os.platform()} ${_os.release()}; ${_os.arch()})` : "pi (browser)";
|
||||
headers.set("User-Agent", userAgent);
|
||||
return headers;
|
||||
}
|
||||
|
||||
function buildSSEHeaders(
|
||||
initHeaders: Record<string, string> | undefined,
|
||||
additionalHeaders: Record<string, string> | undefined,
|
||||
accountId: string,
|
||||
token: string,
|
||||
sessionId?: string,
|
||||
): Headers {
|
||||
const headers = buildBaseCodexHeaders(initHeaders, additionalHeaders, accountId, token);
|
||||
headers.set("OpenAI-Beta", "responses=experimental");
|
||||
headers.set("accept", "text/event-stream");
|
||||
headers.set("content-type", "application/json");
|
||||
|
||||
if (sessionId) {
|
||||
headers.set("session_id", sessionId);
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
function buildWebSocketHeaders(
|
||||
initHeaders: Record<string, string> | undefined,
|
||||
additionalHeaders: Record<string, string> | undefined,
|
||||
accountId: string,
|
||||
token: string,
|
||||
requestId: string,
|
||||
): Headers {
|
||||
const headers = buildBaseCodexHeaders(initHeaders, additionalHeaders, accountId, token);
|
||||
headers.delete("accept");
|
||||
headers.delete("content-type");
|
||||
headers.delete("OpenAI-Beta");
|
||||
headers.delete("openai-beta");
|
||||
headers.set("OpenAI-Beta", OPENAI_BETA_RESPONSES_WEBSOCKETS);
|
||||
headers.set("x-client-request-id", requestId);
|
||||
headers.set("session_id", requestId);
|
||||
return headers;
|
||||
}
|
||||
866
packages/ai/src/providers/openai-completions.ts
Normal file
866
packages/ai/src/providers/openai-completions.ts
Normal file
|
|
@ -0,0 +1,866 @@
|
|||
import OpenAI from "openai";
|
||||
import type {
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionContentPartImage,
|
||||
ChatCompletionContentPartText,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
} from "openai/resources/chat/completions.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost, supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
Model,
|
||||
OpenAICompletionsCompat,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
/**
|
||||
* Check if conversation messages contain tool calls or tool results.
|
||||
* This is needed because Anthropic (via proxy) requires the tools param
|
||||
* to be present when messages include tool_calls or tool role messages.
|
||||
*/
|
||||
function hasToolHistory(messages: Message[]): boolean {
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "toolResult") {
|
||||
return true;
|
||||
}
|
||||
if (msg.role === "assistant") {
|
||||
if (msg.content.some((block) => block.type === "toolCall")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export interface OpenAICompletionsOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
}
|
||||
|
||||
export const streamOpenAICompletions: StreamFunction<"openai-completions", OpenAICompletionsOptions> = (
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
options?: OpenAICompletionsOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, context, apiKey, options?.headers);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming;
|
||||
}
|
||||
const openaiStream = await client.chat.completions.create(params, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
const finishCurrentBlock = (block?: typeof currentBlock) => {
|
||||
if (block) {
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
block.arguments = parseStreamingJson(block.partialArgs);
|
||||
delete block.partialArgs;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: blockIndex(),
|
||||
toolCall: block,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for await (const chunk of openaiStream) {
|
||||
if (!chunk || typeof chunk !== "object") continue;
|
||||
|
||||
// OpenAI documents ChatCompletionChunk.id as the unique chat completion identifier,
|
||||
// and each chunk in a streamed completion carries the same id.
|
||||
output.responseId ||= chunk.id;
|
||||
if (chunk.usage) {
|
||||
output.usage = parseChunkUsage(chunk.usage, model);
|
||||
}
|
||||
|
||||
const choice = Array.isArray(chunk.choices) ? chunk.choices[0] : undefined;
|
||||
if (!choice) continue;
|
||||
|
||||
// Fallback: some providers (e.g., Moonshot) return usage
|
||||
// in choice.usage instead of the standard chunk.usage
|
||||
if (!chunk.usage && (choice as any).usage) {
|
||||
output.usage = parseChunkUsage((choice as any).usage, model);
|
||||
}
|
||||
|
||||
if (choice.finish_reason) {
|
||||
const finishReasonResult = mapStopReason(choice.finish_reason);
|
||||
output.stopReason = finishReasonResult.stopReason;
|
||||
if (finishReasonResult.errorMessage) {
|
||||
output.errorMessage = finishReasonResult.errorMessage;
|
||||
}
|
||||
}
|
||||
|
||||
if (choice.delta) {
|
||||
if (
|
||||
choice.delta.content !== null &&
|
||||
choice.delta.content !== undefined &&
|
||||
choice.delta.content.length > 0
|
||||
) {
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
|
||||
if (currentBlock.type === "text") {
|
||||
currentBlock.text += choice.delta.content;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: choice.delta.content,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Some endpoints return reasoning in reasoning_content (llama.cpp),
|
||||
// or reasoning (other openai compatible endpoints)
|
||||
// Use the first non-empty reasoning field to avoid duplication
|
||||
// (e.g., chutes.ai returns both reasoning_content and reasoning with same content)
|
||||
const reasoningFields = ["reasoning_content", "reasoning", "reasoning_text"];
|
||||
let foundReasoningField: string | null = null;
|
||||
for (const field of reasoningFields) {
|
||||
if (
|
||||
(choice.delta as any)[field] !== null &&
|
||||
(choice.delta as any)[field] !== undefined &&
|
||||
(choice.delta as any)[field].length > 0
|
||||
) {
|
||||
if (!foundReasoningField) {
|
||||
foundReasoningField = field;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (foundReasoningField) {
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: foundReasoningField,
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
|
||||
if (currentBlock.type === "thinking") {
|
||||
const delta = (choice.delta as any)[foundReasoningField];
|
||||
currentBlock.thinking += delta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (choice?.delta?.tool_calls) {
|
||||
for (const toolCall of choice.delta.tool_calls) {
|
||||
if (
|
||||
!currentBlock ||
|
||||
currentBlock.type !== "toolCall" ||
|
||||
(toolCall.id && currentBlock.id !== toolCall.id)
|
||||
) {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: toolCall.id || "",
|
||||
name: toolCall.function?.name || "",
|
||||
arguments: {},
|
||||
partialArgs: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
|
||||
if (currentBlock.type === "toolCall") {
|
||||
if (toolCall.id) currentBlock.id = toolCall.id;
|
||||
if (toolCall.function?.name) currentBlock.name = toolCall.function.name;
|
||||
let delta = "";
|
||||
if (toolCall.function?.arguments) {
|
||||
delta = toolCall.function.arguments;
|
||||
currentBlock.partialArgs += toolCall.function.arguments;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialArgs);
|
||||
}
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const reasoningDetails = (choice.delta as any).reasoning_details;
|
||||
if (reasoningDetails && Array.isArray(reasoningDetails)) {
|
||||
for (const detail of reasoningDetails) {
|
||||
if (detail.type === "reasoning.encrypted" && detail.id && detail.data) {
|
||||
const matchingToolCall = output.content.find(
|
||||
(b) => b.type === "toolCall" && b.id === detail.id,
|
||||
) as ToolCall | undefined;
|
||||
if (matchingToolCall) {
|
||||
matchingToolCall.thoughtSignature = JSON.stringify(detail);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
finishCurrentBlock(currentBlock);
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted") {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
if (output.stopReason === "error") {
|
||||
throw new Error(output.errorMessage || "Provider returned an error stop reason");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as any).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
// Some providers via OpenRouter give additional information in this field.
|
||||
const rawMetadata = (error as any)?.error?.metadata?.raw;
|
||||
if (rawMetadata) output.errorMessage += `\n${rawMetadata}`;
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleOpenAICompletions: StreamFunction<"openai-completions", SimpleStreamOptions> = (
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
|
||||
const toolChoice = (options as OpenAICompletionsOptions | undefined)?.toolChoice;
|
||||
|
||||
return streamOpenAICompletions(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
toolChoice,
|
||||
} satisfies OpenAICompletionsOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
const copilotHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
Object.assign(headers, copilotHeaders);
|
||||
}
|
||||
|
||||
// Merge options headers last so they can override defaults
|
||||
if (optionsHeaders) {
|
||||
Object.assign(headers, optionsHeaders);
|
||||
}
|
||||
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions) {
|
||||
const compat = getCompat(model);
|
||||
const messages = convertMessages(model, context, compat);
|
||||
maybeAddOpenRouterAnthropicCacheControl(model, messages);
|
||||
|
||||
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (compat.supportsUsageInStreaming !== false) {
|
||||
(params as any).stream_options = { include_usage: true };
|
||||
}
|
||||
|
||||
if (compat.supportsStore) {
|
||||
params.store = false;
|
||||
}
|
||||
|
||||
if (options?.maxTokens) {
|
||||
if (compat.maxTokensField === "max_tokens") {
|
||||
(params as any).max_tokens = options.maxTokens;
|
||||
} else {
|
||||
params.max_completion_tokens = options.maxTokens;
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools, compat);
|
||||
} else if (hasToolHistory(context.messages)) {
|
||||
// Anthropic (via LiteLLM/proxy) requires tools param when conversation has tool_calls/tool_results
|
||||
params.tools = [];
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
|
||||
if (compat.thinkingFormat === "zai" && model.reasoning) {
|
||||
(params as any).enable_thinking = !!options?.reasoningEffort;
|
||||
} else if (compat.thinkingFormat === "qwen" && model.reasoning) {
|
||||
(params as any).enable_thinking = !!options?.reasoningEffort;
|
||||
} else if (compat.thinkingFormat === "qwen-chat-template" && model.reasoning) {
|
||||
(params as any).chat_template_kwargs = { enable_thinking: !!options?.reasoningEffort };
|
||||
} else if (compat.thinkingFormat === "openrouter" && model.reasoning) {
|
||||
// OpenRouter normalizes reasoning across providers via a nested reasoning object.
|
||||
const openRouterParams = params as typeof params & { reasoning?: { effort?: string } };
|
||||
if (options?.reasoningEffort) {
|
||||
openRouterParams.reasoning = {
|
||||
effort: mapReasoningEffort(options.reasoningEffort, compat.reasoningEffortMap),
|
||||
};
|
||||
} else {
|
||||
openRouterParams.reasoning = { effort: "none" };
|
||||
}
|
||||
} else if (options?.reasoningEffort && model.reasoning && compat.supportsReasoningEffort) {
|
||||
// OpenAI-style reasoning_effort
|
||||
(params as any).reasoning_effort = mapReasoningEffort(options.reasoningEffort, compat.reasoningEffortMap);
|
||||
}
|
||||
|
||||
// OpenRouter provider routing preferences
|
||||
if (model.baseUrl.includes("openrouter.ai") && model.compat?.openRouterRouting) {
|
||||
(params as any).provider = model.compat.openRouterRouting;
|
||||
}
|
||||
|
||||
// Vercel AI Gateway provider routing preferences
|
||||
if (model.baseUrl.includes("ai-gateway.vercel.sh") && model.compat?.vercelGatewayRouting) {
|
||||
const routing = model.compat.vercelGatewayRouting;
|
||||
if (routing.only || routing.order) {
|
||||
const gatewayOptions: Record<string, string[]> = {};
|
||||
if (routing.only) gatewayOptions.only = routing.only;
|
||||
if (routing.order) gatewayOptions.order = routing.order;
|
||||
(params as any).providerOptions = { gateway: gatewayOptions };
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function mapReasoningEffort(
|
||||
effort: NonNullable<OpenAICompletionsOptions["reasoningEffort"]>,
|
||||
reasoningEffortMap: Partial<Record<NonNullable<OpenAICompletionsOptions["reasoningEffort"]>, string>>,
|
||||
): string {
|
||||
return reasoningEffortMap[effort] ?? effort;
|
||||
}
|
||||
|
||||
function maybeAddOpenRouterAnthropicCacheControl(
|
||||
model: Model<"openai-completions">,
|
||||
messages: ChatCompletionMessageParam[],
|
||||
): void {
|
||||
if (model.provider !== "openrouter" || !model.id.startsWith("anthropic/")) return;
|
||||
|
||||
// Anthropic-style caching requires cache_control on a text part. Add a breakpoint
|
||||
// on the last user/assistant message (walking backwards until we find text content).
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const msg = messages[i];
|
||||
if (msg.role !== "user" && msg.role !== "assistant") continue;
|
||||
|
||||
const content = msg.content;
|
||||
if (typeof content === "string") {
|
||||
msg.content = [
|
||||
Object.assign({ type: "text" as const, text: content }, { cache_control: { type: "ephemeral" } }),
|
||||
];
|
||||
return;
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) continue;
|
||||
|
||||
// Find last text part and add cache_control
|
||||
for (let j = content.length - 1; j >= 0; j--) {
|
||||
const part = content[j];
|
||||
if (part?.type === "text") {
|
||||
Object.assign(part, { cache_control: { type: "ephemeral" } });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function convertMessages(
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
compat: Required<OpenAICompletionsCompat>,
|
||||
): ChatCompletionMessageParam[] {
|
||||
const params: ChatCompletionMessageParam[] = [];
|
||||
|
||||
const normalizeToolCallId = (id: string): string => {
|
||||
// Handle pipe-separated IDs from OpenAI Responses API
|
||||
// Format: {call_id}|{id} where {id} can be 400+ chars with special chars (+, /, =)
|
||||
// These come from providers like github-copilot, openai-codex, opencode
|
||||
// Extract just the call_id part and normalize it
|
||||
if (id.includes("|")) {
|
||||
const [callId] = id.split("|");
|
||||
// Sanitize to allowed chars and truncate to 40 chars (OpenAI limit)
|
||||
return callId.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 40);
|
||||
}
|
||||
|
||||
if (model.provider === "openai") return id.length > 40 ? id.slice(0, 40) : id;
|
||||
return id;
|
||||
};
|
||||
|
||||
const transformedMessages = transformMessages(context.messages, model, (id) => normalizeToolCallId(id));
|
||||
|
||||
if (context.systemPrompt) {
|
||||
const useDeveloperRole = model.reasoning && compat.supportsDeveloperRole;
|
||||
const role = useDeveloperRole ? "developer" : "system";
|
||||
params.push({ role: role, content: sanitizeSurrogates(context.systemPrompt) });
|
||||
}
|
||||
|
||||
let lastRole: string | null = null;
|
||||
|
||||
for (let i = 0; i < transformedMessages.length; i++) {
|
||||
const msg = transformedMessages[i];
|
||||
// Some providers don't allow user messages directly after tool results
|
||||
// Insert a synthetic assistant message to bridge the gap
|
||||
if (compat.requiresAssistantAfterToolResult && lastRole === "toolResult" && msg.role === "user") {
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: "I have processed the tool results.",
|
||||
});
|
||||
}
|
||||
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: sanitizeSurrogates(msg.content),
|
||||
});
|
||||
} else {
|
||||
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(item.text),
|
||||
} satisfies ChatCompletionContentPartText;
|
||||
} else {
|
||||
return {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${item.mimeType};base64,${item.data}`,
|
||||
},
|
||||
} satisfies ChatCompletionContentPartImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !model.input.includes("image")
|
||||
? content.filter((c) => c.type !== "image_url")
|
||||
: content;
|
||||
if (filteredContent.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
// Some providers don't accept null content, use empty string instead
|
||||
const assistantMsg: ChatCompletionAssistantMessageParam = {
|
||||
role: "assistant",
|
||||
content: compat.requiresAssistantAfterToolResult ? "" : null,
|
||||
};
|
||||
|
||||
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
|
||||
// Filter out empty text blocks to avoid API validation errors
|
||||
const nonEmptyTextBlocks = textBlocks.filter((b) => b.text && b.text.trim().length > 0);
|
||||
if (nonEmptyTextBlocks.length > 0) {
|
||||
// Always send assistant content as a plain string (OpenAI Chat Completions
|
||||
// API standard format). Sending as an array of {type:"text", text:"..."}
|
||||
// objects is non-standard and causes some models (e.g. DeepSeek V3.2 via
|
||||
// NVIDIA NIM) to mirror the content-block structure literally in their
|
||||
// output, producing recursive nesting like [{'type':'text','text':'[{...}]'}].
|
||||
assistantMsg.content = nonEmptyTextBlocks.map((b) => sanitizeSurrogates(b.text)).join("");
|
||||
}
|
||||
|
||||
// Handle thinking blocks
|
||||
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
|
||||
// Filter out empty thinking blocks to avoid API validation errors
|
||||
const nonEmptyThinkingBlocks = thinkingBlocks.filter((b) => b.thinking && b.thinking.trim().length > 0);
|
||||
if (nonEmptyThinkingBlocks.length > 0) {
|
||||
if (compat.requiresThinkingAsText) {
|
||||
// Convert thinking blocks to plain text (no tags to avoid model mimicking them)
|
||||
const thinkingText = nonEmptyThinkingBlocks.map((b) => b.thinking).join("\n\n");
|
||||
const textContent = assistantMsg.content as Array<{ type: "text"; text: string }> | null;
|
||||
if (textContent) {
|
||||
textContent.unshift({ type: "text", text: thinkingText });
|
||||
} else {
|
||||
assistantMsg.content = [{ type: "text", text: thinkingText }];
|
||||
}
|
||||
} else {
|
||||
// Use the signature from the first thinking block if available (for llama.cpp server + gpt-oss)
|
||||
const signature = nonEmptyThinkingBlocks[0].thinkingSignature;
|
||||
if (signature && signature.length > 0) {
|
||||
(assistantMsg as any)[signature] = nonEmptyThinkingBlocks.map((b) => b.thinking).join("\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
assistantMsg.tool_calls = toolCalls.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: JSON.stringify(tc.arguments),
|
||||
},
|
||||
}));
|
||||
const reasoningDetails = toolCalls
|
||||
.filter((tc) => tc.thoughtSignature)
|
||||
.map((tc) => {
|
||||
try {
|
||||
return JSON.parse(tc.thoughtSignature!);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter(Boolean);
|
||||
if (reasoningDetails.length > 0) {
|
||||
(assistantMsg as any).reasoning_details = reasoningDetails;
|
||||
}
|
||||
}
|
||||
// Skip assistant messages that have no content and no tool calls.
|
||||
// Some providers require "either content or tool_calls, but not none".
|
||||
// Other providers also don't accept empty assistant messages.
|
||||
// This handles aborted assistant responses that got no content.
|
||||
const content = assistantMsg.content;
|
||||
const hasContent =
|
||||
content !== null &&
|
||||
content !== undefined &&
|
||||
(typeof content === "string" ? content.length > 0 : content.length > 0);
|
||||
if (!hasContent && !assistantMsg.tool_calls) {
|
||||
continue;
|
||||
}
|
||||
params.push(assistantMsg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
const imageBlocks: Array<{ type: "image_url"; image_url: { url: string } }> = [];
|
||||
let j = i;
|
||||
|
||||
for (; j < transformedMessages.length && transformedMessages[j].role === "toolResult"; j++) {
|
||||
const toolMsg = transformedMessages[j] as ToolResultMessage;
|
||||
|
||||
// Extract text and image content
|
||||
const textResult = toolMsg.content
|
||||
.filter((c) => c.type === "text")
|
||||
.map((c) => (c as any).text)
|
||||
.join("\n");
|
||||
const hasImages = toolMsg.content.some((c) => c.type === "image");
|
||||
|
||||
// Always send tool result with text (or placeholder if only images)
|
||||
const hasText = textResult.length > 0;
|
||||
// Some providers require the 'name' field in tool results
|
||||
const toolResultMsg: ChatCompletionToolMessageParam = {
|
||||
role: "tool",
|
||||
content: sanitizeSurrogates(hasText ? textResult : "(see attached image)"),
|
||||
tool_call_id: toolMsg.toolCallId,
|
||||
};
|
||||
if (compat.requiresToolResultName && toolMsg.toolName) {
|
||||
(toolResultMsg as any).name = toolMsg.toolName;
|
||||
}
|
||||
params.push(toolResultMsg);
|
||||
|
||||
if (hasImages && model.input.includes("image")) {
|
||||
for (const block of toolMsg.content) {
|
||||
if (block.type === "image") {
|
||||
imageBlocks.push({
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${(block as any).mimeType};base64,${(block as any).data}`,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i = j - 1;
|
||||
|
||||
if (imageBlocks.length > 0) {
|
||||
if (compat.requiresAssistantAfterToolResult) {
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: "I have processed the tool results.",
|
||||
});
|
||||
}
|
||||
|
||||
params.push({
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Attached image(s) from tool result:",
|
||||
},
|
||||
...imageBlocks,
|
||||
],
|
||||
});
|
||||
lastRole = "user";
|
||||
} else {
|
||||
lastRole = "toolResult";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
lastRole = msg.role;
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function convertTools(
|
||||
tools: Tool[],
|
||||
compat: Required<OpenAICompletionsCompat>,
|
||||
): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
|
||||
// Only include strict if provider supports it. Some reject unknown fields.
|
||||
...(compat.supportsStrictMode !== false && { strict: false }),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function parseChunkUsage(
|
||||
rawUsage: {
|
||||
prompt_tokens?: number;
|
||||
completion_tokens?: number;
|
||||
prompt_tokens_details?: { cached_tokens?: number };
|
||||
completion_tokens_details?: { reasoning_tokens?: number };
|
||||
},
|
||||
model: Model<"openai-completions">,
|
||||
): AssistantMessage["usage"] {
|
||||
const cachedTokens = rawUsage.prompt_tokens_details?.cached_tokens || 0;
|
||||
const reasoningTokens = rawUsage.completion_tokens_details?.reasoning_tokens || 0;
|
||||
// OpenAI includes cached tokens in prompt_tokens, so subtract to get non-cached input
|
||||
const input = (rawUsage.prompt_tokens || 0) - cachedTokens;
|
||||
// Compute totalTokens ourselves since we add reasoning_tokens to output
|
||||
// and some providers (e.g., Groq) don't include them in total_tokens
|
||||
const outputTokens = (rawUsage.completion_tokens || 0) + reasoningTokens;
|
||||
const usage: AssistantMessage["usage"] = {
|
||||
input,
|
||||
output: outputTokens,
|
||||
cacheRead: cachedTokens,
|
||||
cacheWrite: 0,
|
||||
totalTokens: input + outputTokens + cachedTokens,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
calculateCost(model, usage);
|
||||
return usage;
|
||||
}
|
||||
|
||||
function mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"] | string): {
|
||||
stopReason: StopReason;
|
||||
errorMessage?: string;
|
||||
} {
|
||||
if (reason === null) return { stopReason: "stop" };
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
case "end":
|
||||
return { stopReason: "stop" };
|
||||
case "length":
|
||||
return { stopReason: "length" };
|
||||
case "function_call":
|
||||
case "tool_calls":
|
||||
return { stopReason: "toolUse" };
|
||||
case "content_filter":
|
||||
return { stopReason: "error", errorMessage: "Provider finish_reason: content_filter" };
|
||||
case "network_error":
|
||||
return { stopReason: "error", errorMessage: "Provider finish_reason: network_error" };
|
||||
default:
|
||||
return {
|
||||
stopReason: "error",
|
||||
errorMessage: `Provider finish_reason: ${reason}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect compatibility settings from provider and baseUrl for known providers.
|
||||
* Provider takes precedence over URL-based detection since it's explicitly configured.
|
||||
* Returns a fully resolved OpenAICompletionsCompat object with all fields set.
|
||||
*/
|
||||
function detectCompat(model: Model<"openai-completions">): Required<OpenAICompletionsCompat> {
|
||||
const provider = model.provider;
|
||||
const baseUrl = model.baseUrl;
|
||||
|
||||
const isZai = provider === "zai" || baseUrl.includes("api.z.ai");
|
||||
|
||||
const isNonStandard =
|
||||
provider === "cerebras" ||
|
||||
baseUrl.includes("cerebras.ai") ||
|
||||
provider === "xai" ||
|
||||
baseUrl.includes("api.x.ai") ||
|
||||
baseUrl.includes("chutes.ai") ||
|
||||
baseUrl.includes("deepseek.com") ||
|
||||
isZai ||
|
||||
provider === "opencode" ||
|
||||
baseUrl.includes("opencode.ai");
|
||||
|
||||
const useMaxTokens = baseUrl.includes("chutes.ai");
|
||||
|
||||
const isGrok = provider === "xai" || baseUrl.includes("api.x.ai");
|
||||
const isGroq = provider === "groq" || baseUrl.includes("groq.com");
|
||||
|
||||
const reasoningEffortMap =
|
||||
isGroq && model.id === "qwen/qwen3-32b"
|
||||
? {
|
||||
minimal: "default",
|
||||
low: "default",
|
||||
medium: "default",
|
||||
high: "default",
|
||||
xhigh: "default",
|
||||
}
|
||||
: {};
|
||||
return {
|
||||
supportsStore: !isNonStandard,
|
||||
supportsDeveloperRole: !isNonStandard,
|
||||
supportsReasoningEffort: !isGrok && !isZai,
|
||||
reasoningEffortMap,
|
||||
supportsUsageInStreaming: true,
|
||||
maxTokensField: useMaxTokens ? "max_tokens" : "max_completion_tokens",
|
||||
requiresToolResultName: false,
|
||||
requiresAssistantAfterToolResult: false,
|
||||
requiresThinkingAsText: false,
|
||||
thinkingFormat: isZai
|
||||
? "zai"
|
||||
: provider === "openrouter" || baseUrl.includes("openrouter.ai")
|
||||
? "openrouter"
|
||||
: "openai",
|
||||
openRouterRouting: {},
|
||||
vercelGatewayRouting: {},
|
||||
supportsStrictMode: true,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get resolved compatibility settings for a model.
|
||||
* Uses explicit model.compat if provided, otherwise auto-detects from provider/URL.
|
||||
*/
|
||||
function getCompat(model: Model<"openai-completions">): Required<OpenAICompletionsCompat> {
|
||||
const detected = detectCompat(model);
|
||||
if (!model.compat) return detected;
|
||||
|
||||
return {
|
||||
supportsStore: model.compat.supportsStore ?? detected.supportsStore,
|
||||
supportsDeveloperRole: model.compat.supportsDeveloperRole ?? detected.supportsDeveloperRole,
|
||||
supportsReasoningEffort: model.compat.supportsReasoningEffort ?? detected.supportsReasoningEffort,
|
||||
reasoningEffortMap: model.compat.reasoningEffortMap ?? detected.reasoningEffortMap,
|
||||
supportsUsageInStreaming: model.compat.supportsUsageInStreaming ?? detected.supportsUsageInStreaming,
|
||||
maxTokensField: model.compat.maxTokensField ?? detected.maxTokensField,
|
||||
requiresToolResultName: model.compat.requiresToolResultName ?? detected.requiresToolResultName,
|
||||
requiresAssistantAfterToolResult:
|
||||
model.compat.requiresAssistantAfterToolResult ?? detected.requiresAssistantAfterToolResult,
|
||||
requiresThinkingAsText: model.compat.requiresThinkingAsText ?? detected.requiresThinkingAsText,
|
||||
thinkingFormat: model.compat.thinkingFormat ?? detected.thinkingFormat,
|
||||
openRouterRouting: model.compat.openRouterRouting ?? {},
|
||||
vercelGatewayRouting: model.compat.vercelGatewayRouting ?? detected.vercelGatewayRouting,
|
||||
supportsStrictMode: model.compat.supportsStrictMode ?? detected.supportsStrictMode,
|
||||
};
|
||||
}
|
||||
513
packages/ai/src/providers/openai-responses-shared.ts
Normal file
513
packages/ai/src/providers/openai-responses-shared.ts
Normal file
|
|
@ -0,0 +1,513 @@
|
|||
import type OpenAI from "openai";
|
||||
import type {
|
||||
Tool as OpenAITool,
|
||||
ResponseCreateParamsStreaming,
|
||||
ResponseFunctionCallOutputItemList,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseInput,
|
||||
ResponseInputContent,
|
||||
ResponseInputImage,
|
||||
ResponseInputText,
|
||||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
ResponseStreamEvent,
|
||||
} from "openai/resources/responses/responses.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
TextSignatureV1,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { shortHash } from "../utils/hash.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
// =============================================================================
|
||||
// Utilities
|
||||
// =============================================================================
|
||||
|
||||
function encodeTextSignatureV1(id: string, phase?: TextSignatureV1["phase"]): string {
|
||||
const payload: TextSignatureV1 = { v: 1, id };
|
||||
if (phase) payload.phase = phase;
|
||||
return JSON.stringify(payload);
|
||||
}
|
||||
|
||||
function parseTextSignature(
|
||||
signature: string | undefined,
|
||||
): { id: string; phase?: TextSignatureV1["phase"] } | undefined {
|
||||
if (!signature) return undefined;
|
||||
if (signature.startsWith("{")) {
|
||||
try {
|
||||
const parsed = JSON.parse(signature) as Partial<TextSignatureV1>;
|
||||
if (parsed.v === 1 && typeof parsed.id === "string") {
|
||||
if (parsed.phase === "commentary" || parsed.phase === "final_answer") {
|
||||
return { id: parsed.id, phase: parsed.phase };
|
||||
}
|
||||
return { id: parsed.id };
|
||||
}
|
||||
} catch {
|
||||
// Fall through to legacy plain-string handling.
|
||||
}
|
||||
}
|
||||
return { id: signature };
|
||||
}
|
||||
|
||||
export interface OpenAIResponsesStreamOptions {
|
||||
serviceTier?: ResponseCreateParamsStreaming["service_tier"];
|
||||
applyServiceTierPricing?: (
|
||||
usage: Usage,
|
||||
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
|
||||
) => void;
|
||||
}
|
||||
|
||||
export interface ConvertResponsesMessagesOptions {
|
||||
includeSystemPrompt?: boolean;
|
||||
}
|
||||
|
||||
export interface ConvertResponsesToolsOptions {
|
||||
strict?: boolean | null;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Message conversion
|
||||
// =============================================================================
|
||||
|
||||
export function convertResponsesMessages<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
allowedToolCallProviders: ReadonlySet<string>,
|
||||
options?: ConvertResponsesMessagesOptions,
|
||||
): ResponseInput {
|
||||
const messages: ResponseInput = [];
|
||||
|
||||
const normalizeIdPart = (part: string): string => {
|
||||
const sanitized = part.replace(/[^a-zA-Z0-9_-]/g, "_");
|
||||
const normalized = sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
|
||||
return normalized.replace(/_+$/, "");
|
||||
};
|
||||
|
||||
const buildForeignResponsesItemId = (itemId: string): string => {
|
||||
const normalized = `fc_${shortHash(itemId)}`;
|
||||
return normalized.length > 64 ? normalized.slice(0, 64) : normalized;
|
||||
};
|
||||
|
||||
const normalizeToolCallId = (id: string, _targetModel: Model<TApi>, source: AssistantMessage): string => {
|
||||
if (!allowedToolCallProviders.has(model.provider)) return normalizeIdPart(id);
|
||||
if (!id.includes("|")) return normalizeIdPart(id);
|
||||
const [callId, itemId] = id.split("|");
|
||||
const normalizedCallId = normalizeIdPart(callId);
|
||||
const isForeignToolCall = source.provider !== model.provider || source.api !== model.api;
|
||||
let normalizedItemId = isForeignToolCall ? buildForeignResponsesItemId(itemId) : normalizeIdPart(itemId);
|
||||
// OpenAI Responses API requires item id to start with "fc"
|
||||
if (!normalizedItemId.startsWith("fc_")) {
|
||||
normalizedItemId = normalizeIdPart(`fc_${normalizedItemId}`);
|
||||
}
|
||||
return `${normalizedCallId}|${normalizedItemId}`;
|
||||
};
|
||||
|
||||
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
||||
|
||||
const includeSystemPrompt = options?.includeSystemPrompt ?? true;
|
||||
if (includeSystemPrompt && context.systemPrompt) {
|
||||
const role = model.reasoning ? "developer" : "system";
|
||||
messages.push({
|
||||
role,
|
||||
content: sanitizeSurrogates(context.systemPrompt),
|
||||
});
|
||||
}
|
||||
|
||||
let msgIndex = 0;
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: sanitizeSurrogates(msg.content) }],
|
||||
});
|
||||
} else {
|
||||
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "input_text",
|
||||
text: sanitizeSurrogates(item.text),
|
||||
} satisfies ResponseInputText;
|
||||
}
|
||||
return {
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${item.mimeType};base64,${item.data}`,
|
||||
} satisfies ResponseInputImage;
|
||||
});
|
||||
const filteredContent = !model.input.includes("image")
|
||||
? content.filter((c) => c.type !== "input_image")
|
||||
: content;
|
||||
if (filteredContent.length === 0) continue;
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const output: ResponseInput = [];
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
const isDifferentModel =
|
||||
assistantMsg.model !== model.id &&
|
||||
assistantMsg.provider === model.provider &&
|
||||
assistantMsg.api === model.api;
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "thinking") {
|
||||
if (block.thinkingSignature) {
|
||||
const reasoningItem = JSON.parse(block.thinkingSignature) as ResponseReasoningItem;
|
||||
output.push(reasoningItem);
|
||||
}
|
||||
} else if (block.type === "text") {
|
||||
const textBlock = block as TextContent;
|
||||
const parsedSignature = parseTextSignature(textBlock.textSignature);
|
||||
// OpenAI requires id to be max 64 characters
|
||||
let msgId = parsedSignature?.id;
|
||||
if (!msgId) {
|
||||
msgId = `msg_${msgIndex}`;
|
||||
} else if (msgId.length > 64) {
|
||||
msgId = `msg_${shortHash(msgId)}`;
|
||||
}
|
||||
output.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "output_text", text: sanitizeSurrogates(textBlock.text), annotations: [] }],
|
||||
status: "completed",
|
||||
id: msgId,
|
||||
phase: parsedSignature?.phase,
|
||||
} satisfies ResponseOutputMessage);
|
||||
} else if (block.type === "toolCall") {
|
||||
const toolCall = block as ToolCall;
|
||||
const [callId, itemIdRaw] = toolCall.id.split("|");
|
||||
let itemId: string | undefined = itemIdRaw;
|
||||
|
||||
// For different-model messages, set id to undefined to avoid pairing validation.
|
||||
// OpenAI tracks which fc_xxx IDs were paired with rs_xxx reasoning items.
|
||||
// By omitting the id, we avoid triggering that validation (like cross-provider does).
|
||||
if (isDifferentModel && itemId?.startsWith("fc_")) {
|
||||
itemId = undefined;
|
||||
}
|
||||
|
||||
output.push({
|
||||
type: "function_call",
|
||||
id: itemId,
|
||||
call_id: callId,
|
||||
name: toolCall.name,
|
||||
arguments: JSON.stringify(toolCall.arguments),
|
||||
});
|
||||
}
|
||||
}
|
||||
if (output.length === 0) continue;
|
||||
messages.push(...output);
|
||||
} else if (msg.role === "toolResult") {
|
||||
const textResult = msg.content
|
||||
.filter((c): c is TextContent => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
const hasImages = msg.content.some((c): c is ImageContent => c.type === "image");
|
||||
const hasText = textResult.length > 0;
|
||||
const [callId] = msg.toolCallId.split("|");
|
||||
|
||||
let output: string | ResponseFunctionCallOutputItemList;
|
||||
if (hasImages && model.input.includes("image")) {
|
||||
const contentParts: ResponseFunctionCallOutputItemList = [];
|
||||
|
||||
if (hasText) {
|
||||
contentParts.push({
|
||||
type: "input_text",
|
||||
text: sanitizeSurrogates(textResult),
|
||||
});
|
||||
}
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "image") {
|
||||
contentParts.push({
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${block.mimeType};base64,${block.data}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
output = contentParts;
|
||||
} else {
|
||||
output = sanitizeSurrogates(hasText ? textResult : "(see attached image)");
|
||||
}
|
||||
|
||||
messages.push({
|
||||
type: "function_call_output",
|
||||
call_id: callId,
|
||||
output,
|
||||
});
|
||||
}
|
||||
msgIndex++;
|
||||
}
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool conversion
|
||||
// =============================================================================
|
||||
|
||||
export function convertResponsesTools(tools: Tool[], options?: ConvertResponsesToolsOptions): OpenAITool[] {
|
||||
const strict = options?.strict === undefined ? false : options.strict;
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
|
||||
strict,
|
||||
}));
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Stream processing
|
||||
// =============================================================================
|
||||
|
||||
export async function processResponsesStream<TApi extends Api>(
|
||||
openaiStream: AsyncIterable<ResponseStreamEvent>,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
model: Model<TApi>,
|
||||
options?: OpenAIResponsesStreamOptions,
|
||||
): Promise<void> {
|
||||
let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null;
|
||||
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
|
||||
for await (const event of openaiStream) {
|
||||
if (event.type === "response.created") {
|
||||
output.responseId = event.response.id;
|
||||
} else if (event.type === "response.output_item.added") {
|
||||
const item = event.item;
|
||||
if (item.type === "reasoning") {
|
||||
currentItem = item;
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
} else if (item.type === "message") {
|
||||
currentItem = item;
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
} else if (item.type === "function_call") {
|
||||
currentItem = item;
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: `${item.call_id}|${item.id}`,
|
||||
name: item.name,
|
||||
arguments: {},
|
||||
partialJson: item.arguments || "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_part.added") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
currentItem.summary.push(event.part);
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_text.delta") {
|
||||
if (currentItem?.type === "reasoning" && currentBlock?.type === "thinking") {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_part.done") {
|
||||
if (currentItem?.type === "reasoning" && currentBlock?.type === "thinking") {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += "\n\n";
|
||||
lastPart.text += "\n\n";
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: "\n\n",
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.content_part.added") {
|
||||
if (currentItem?.type === "message") {
|
||||
currentItem.content = currentItem.content || [];
|
||||
// Filter out ReasoningText, only accept output_text and refusal
|
||||
if (event.part.type === "output_text" || event.part.type === "refusal") {
|
||||
currentItem.content.push(event.part);
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.output_text.delta") {
|
||||
if (currentItem?.type === "message" && currentBlock?.type === "text") {
|
||||
if (!currentItem.content || currentItem.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart?.type === "output_text") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.refusal.delta") {
|
||||
if (currentItem?.type === "message" && currentBlock?.type === "text") {
|
||||
if (!currentItem.content || currentItem.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart?.type === "refusal") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.refusal += event.delta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.function_call_arguments.delta") {
|
||||
if (currentItem?.type === "function_call" && currentBlock?.type === "toolCall") {
|
||||
currentBlock.partialJson += event.delta;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "response.function_call_arguments.done") {
|
||||
if (currentItem?.type === "function_call" && currentBlock?.type === "toolCall") {
|
||||
currentBlock.partialJson = event.arguments;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
|
||||
}
|
||||
} else if (event.type === "response.output_item.done") {
|
||||
const item = event.item;
|
||||
|
||||
if (item.type === "reasoning" && currentBlock?.type === "thinking") {
|
||||
currentBlock.thinking = item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
currentBlock.thinkingSignature = JSON.stringify(item);
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "message" && currentBlock?.type === "text") {
|
||||
currentBlock.text = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
|
||||
currentBlock.textSignature = encodeTextSignatureV1(item.id, item.phase ?? undefined);
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "function_call") {
|
||||
const args =
|
||||
currentBlock?.type === "toolCall" && currentBlock.partialJson
|
||||
? parseStreamingJson(currentBlock.partialJson)
|
||||
: parseStreamingJson(item.arguments || "{}");
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: `${item.call_id}|${item.id}`,
|
||||
name: item.name,
|
||||
arguments: args,
|
||||
};
|
||||
|
||||
currentBlock = null;
|
||||
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
|
||||
}
|
||||
} else if (event.type === "response.completed") {
|
||||
const response = event.response;
|
||||
if (response?.id) {
|
||||
output.responseId = response.id;
|
||||
}
|
||||
if (response?.usage) {
|
||||
const cachedTokens = response.usage.input_tokens_details?.cached_tokens || 0;
|
||||
output.usage = {
|
||||
// OpenAI includes cached tokens in input_tokens, so subtract to get non-cached input
|
||||
input: (response.usage.input_tokens || 0) - cachedTokens,
|
||||
output: response.usage.output_tokens || 0,
|
||||
cacheRead: cachedTokens,
|
||||
cacheWrite: 0,
|
||||
totalTokens: response.usage.total_tokens || 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
calculateCost(model, output.usage);
|
||||
if (options?.applyServiceTierPricing) {
|
||||
const serviceTier = response?.service_tier ?? options.serviceTier;
|
||||
options.applyServiceTierPricing(output.usage, serviceTier);
|
||||
}
|
||||
// Map status to stop reason
|
||||
output.stopReason = mapStopReason(response?.status);
|
||||
if (output.content.some((b) => b.type === "toolCall") && output.stopReason === "stop") {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
} else if (event.type === "error") {
|
||||
throw new Error(`Error Code ${event.code}: ${event.message}` || "Unknown error");
|
||||
} else if (event.type === "response.failed") {
|
||||
const error = event.response?.error;
|
||||
const details = event.response?.incomplete_details;
|
||||
const msg = error
|
||||
? `${error.code || "unknown"}: ${error.message || "no message"}`
|
||||
: details?.reason
|
||||
? `incomplete: ${details.reason}`
|
||||
: "Unknown error (no error details in response)";
|
||||
throw new Error(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function mapStopReason(status: OpenAI.Responses.ResponseStatus | undefined): StopReason {
|
||||
if (!status) return "stop";
|
||||
switch (status) {
|
||||
case "completed":
|
||||
return "stop";
|
||||
case "incomplete":
|
||||
return "length";
|
||||
case "failed":
|
||||
case "cancelled":
|
||||
return "error";
|
||||
// These two are wonky ...
|
||||
case "in_progress":
|
||||
case "queued":
|
||||
return "stop";
|
||||
default: {
|
||||
const _exhaustive: never = status;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
251
packages/ai/src/providers/openai-responses.ts
Normal file
251
packages/ai/src/providers/openai-responses.ts
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
import OpenAI from "openai";
|
||||
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
|
||||
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt cache retention based on cacheRetention and base URL.
|
||||
* Only applies to direct OpenAI API calls (api.openai.com).
|
||||
*/
|
||||
function getPromptCacheRetention(baseUrl: string, cacheRetention: CacheRetention): "24h" | undefined {
|
||||
if (cacheRetention !== "long") {
|
||||
return undefined;
|
||||
}
|
||||
if (baseUrl.includes("api.openai.com")) {
|
||||
return "24h";
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// OpenAI Responses-specific options
|
||||
export interface OpenAIResponsesOptions extends StreamOptions {
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
reasoningSummary?: "auto" | "detailed" | "concise" | null;
|
||||
serviceTier?: ResponseCreateParamsStreaming["service_tier"];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate function for OpenAI Responses API
|
||||
*/
|
||||
export const streamOpenAIResponses: StreamFunction<"openai-responses", OpenAIResponsesOptions> = (
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
options?: OpenAIResponsesOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
// Create OpenAI client
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, context, apiKey, options?.headers);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as ResponseCreateParamsStreaming;
|
||||
}
|
||||
const openaiStream = await client.responses.create(
|
||||
params,
|
||||
options?.signal ? { signal: options.signal } : undefined,
|
||||
);
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
await processResponsesStream(openaiStream, output, stream, model, {
|
||||
serviceTier: options?.serviceTier,
|
||||
applyServiceTierPricing,
|
||||
});
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as { index?: number }).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleOpenAIResponses: StreamFunction<"openai-responses", SimpleStreamOptions> = (
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
|
||||
|
||||
return streamOpenAIResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies OpenAIResponsesOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
const copilotHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
Object.assign(headers, copilotHeaders);
|
||||
}
|
||||
|
||||
// Merge options headers last so they can override defaults
|
||||
if (optionsHeaders) {
|
||||
Object.assign(headers, optionsHeaders);
|
||||
}
|
||||
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
||||
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
|
||||
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
input: messages,
|
||||
stream: true,
|
||||
prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
|
||||
prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
|
||||
store: false,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (options?.serviceTier !== undefined) {
|
||||
params.service_tier = options.serviceTier;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertResponsesTools(context.tools);
|
||||
}
|
||||
|
||||
if (model.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
params.reasoning = { effort: "none" };
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function getServiceTierCostMultiplier(serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined): number {
|
||||
switch (serviceTier) {
|
||||
case "flex":
|
||||
return 0.5;
|
||||
case "priority":
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
function applyServiceTierPricing(usage: Usage, serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined) {
|
||||
const multiplier = getServiceTierCostMultiplier(serviceTier);
|
||||
if (multiplier === 1) return;
|
||||
|
||||
usage.cost.input *= multiplier;
|
||||
usage.cost.output *= multiplier;
|
||||
usage.cost.cacheRead *= multiplier;
|
||||
usage.cost.cacheWrite *= multiplier;
|
||||
usage.cost.total = usage.cost.input + usage.cost.output + usage.cost.cacheRead + usage.cost.cacheWrite;
|
||||
}
|
||||
433
packages/ai/src/providers/register-builtins.ts
Normal file
433
packages/ai/src/providers/register-builtins.ts
Normal file
|
|
@ -0,0 +1,433 @@
|
|||
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import type { BedrockOptions } from "./amazon-bedrock.js";
|
||||
import type { AnthropicOptions } from "./anthropic.js";
|
||||
import type { AzureOpenAIResponsesOptions } from "./azure-openai-responses.js";
|
||||
import type { GoogleOptions } from "./google.js";
|
||||
import type { GoogleGeminiCliOptions } from "./google-gemini-cli.js";
|
||||
import type { GoogleVertexOptions } from "./google-vertex.js";
|
||||
import type { MistralOptions } from "./mistral.js";
|
||||
import type { OpenAICodexResponsesOptions } from "./openai-codex-responses.js";
|
||||
import type { OpenAICompletionsOptions } from "./openai-completions.js";
|
||||
import type { OpenAIResponsesOptions } from "./openai-responses.js";
|
||||
|
||||
interface LazyProviderModule<
|
||||
TApi extends Api,
|
||||
TOptions extends StreamOptions,
|
||||
TSimpleOptions extends SimpleStreamOptions,
|
||||
> {
|
||||
stream: (model: Model<TApi>, context: Context, options?: TOptions) => AsyncIterable<AssistantMessageEvent>;
|
||||
streamSimple: (
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: TSimpleOptions,
|
||||
) => AsyncIterable<AssistantMessageEvent>;
|
||||
}
|
||||
|
||||
interface AnthropicProviderModule {
|
||||
streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOptions>;
|
||||
streamSimpleAnthropic: StreamFunction<"anthropic-messages", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface AzureOpenAIResponsesProviderModule {
|
||||
streamAzureOpenAIResponses: StreamFunction<"azure-openai-responses", AzureOpenAIResponsesOptions>;
|
||||
streamSimpleAzureOpenAIResponses: StreamFunction<"azure-openai-responses", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface GoogleProviderModule {
|
||||
streamGoogle: StreamFunction<"google-generative-ai", GoogleOptions>;
|
||||
streamSimpleGoogle: StreamFunction<"google-generative-ai", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface GoogleGeminiCliProviderModule {
|
||||
streamGoogleGeminiCli: StreamFunction<"google-gemini-cli", GoogleGeminiCliOptions>;
|
||||
streamSimpleGoogleGeminiCli: StreamFunction<"google-gemini-cli", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface GoogleVertexProviderModule {
|
||||
streamGoogleVertex: StreamFunction<"google-vertex", GoogleVertexOptions>;
|
||||
streamSimpleGoogleVertex: StreamFunction<"google-vertex", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface MistralProviderModule {
|
||||
streamMistral: StreamFunction<"mistral-conversations", MistralOptions>;
|
||||
streamSimpleMistral: StreamFunction<"mistral-conversations", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface OpenAICodexResponsesProviderModule {
|
||||
streamOpenAICodexResponses: StreamFunction<"openai-codex-responses", OpenAICodexResponsesOptions>;
|
||||
streamSimpleOpenAICodexResponses: StreamFunction<"openai-codex-responses", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface OpenAICompletionsProviderModule {
|
||||
streamOpenAICompletions: StreamFunction<"openai-completions", OpenAICompletionsOptions>;
|
||||
streamSimpleOpenAICompletions: StreamFunction<"openai-completions", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface OpenAIResponsesProviderModule {
|
||||
streamOpenAIResponses: StreamFunction<"openai-responses", OpenAIResponsesOptions>;
|
||||
streamSimpleOpenAIResponses: StreamFunction<"openai-responses", SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface BedrockProviderModule {
|
||||
streamBedrock: (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: BedrockOptions,
|
||||
) => AsyncIterable<AssistantMessageEvent>;
|
||||
streamSimpleBedrock: (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => AsyncIterable<AssistantMessageEvent>;
|
||||
}
|
||||
|
||||
const importNodeOnlyProvider = (specifier: string): Promise<unknown> => import(specifier);
|
||||
|
||||
let anthropicProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"anthropic-messages", AnthropicOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let azureOpenAIResponsesProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"azure-openai-responses", AzureOpenAIResponsesOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let googleProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"google-generative-ai", GoogleOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let googleGeminiCliProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"google-gemini-cli", GoogleGeminiCliOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let googleVertexProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"google-vertex", GoogleVertexOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let mistralProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"mistral-conversations", MistralOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let openAICodexResponsesProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"openai-codex-responses", OpenAICodexResponsesOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let openAICompletionsProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"openai-completions", OpenAICompletionsOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let openAIResponsesProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"openai-responses", OpenAIResponsesOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
let bedrockProviderModuleOverride:
|
||||
| LazyProviderModule<"bedrock-converse-stream", BedrockOptions, SimpleStreamOptions>
|
||||
| undefined;
|
||||
let bedrockProviderModulePromise:
|
||||
| Promise<LazyProviderModule<"bedrock-converse-stream", BedrockOptions, SimpleStreamOptions>>
|
||||
| undefined;
|
||||
|
||||
export function setBedrockProviderModule(module: BedrockProviderModule): void {
|
||||
bedrockProviderModuleOverride = {
|
||||
stream: module.streamBedrock,
|
||||
streamSimple: module.streamSimpleBedrock,
|
||||
};
|
||||
}
|
||||
|
||||
function forwardStream(target: AssistantMessageEventStream, source: AsyncIterable<AssistantMessageEvent>): void {
|
||||
(async () => {
|
||||
for await (const event of source) {
|
||||
target.push(event);
|
||||
}
|
||||
target.end();
|
||||
})();
|
||||
}
|
||||
|
||||
function createLazyLoadErrorMessage<TApi extends Api>(model: Model<TApi>, error: unknown): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "error",
|
||||
errorMessage: error instanceof Error ? error.message : String(error),
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function createLazyStream<TApi extends Api, TOptions extends StreamOptions, TSimpleOptions extends SimpleStreamOptions>(
|
||||
loadModule: () => Promise<LazyProviderModule<TApi, TOptions, TSimpleOptions>>,
|
||||
): StreamFunction<TApi, TOptions> {
|
||||
return (model, context, options) => {
|
||||
const outer = new AssistantMessageEventStream();
|
||||
|
||||
loadModule()
|
||||
.then((module) => {
|
||||
const inner = module.stream(model, context, options);
|
||||
forwardStream(outer, inner);
|
||||
})
|
||||
.catch((error) => {
|
||||
const message = createLazyLoadErrorMessage(model, error);
|
||||
outer.push({ type: "error", reason: "error", error: message });
|
||||
outer.end(message);
|
||||
});
|
||||
|
||||
return outer;
|
||||
};
|
||||
}
|
||||
|
||||
function createLazySimpleStream<
|
||||
TApi extends Api,
|
||||
TOptions extends StreamOptions,
|
||||
TSimpleOptions extends SimpleStreamOptions,
|
||||
>(loadModule: () => Promise<LazyProviderModule<TApi, TOptions, TSimpleOptions>>): StreamFunction<TApi, TSimpleOptions> {
|
||||
return (model, context, options) => {
|
||||
const outer = new AssistantMessageEventStream();
|
||||
|
||||
loadModule()
|
||||
.then((module) => {
|
||||
const inner = module.streamSimple(model, context, options);
|
||||
forwardStream(outer, inner);
|
||||
})
|
||||
.catch((error) => {
|
||||
const message = createLazyLoadErrorMessage(model, error);
|
||||
outer.push({ type: "error", reason: "error", error: message });
|
||||
outer.end(message);
|
||||
});
|
||||
|
||||
return outer;
|
||||
};
|
||||
}
|
||||
|
||||
function loadAnthropicProviderModule(): Promise<
|
||||
LazyProviderModule<"anthropic-messages", AnthropicOptions, SimpleStreamOptions>
|
||||
> {
|
||||
anthropicProviderModulePromise ||= import("./anthropic.js").then((module) => {
|
||||
const provider = module as AnthropicProviderModule;
|
||||
return {
|
||||
stream: provider.streamAnthropic,
|
||||
streamSimple: provider.streamSimpleAnthropic,
|
||||
};
|
||||
});
|
||||
return anthropicProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadAzureOpenAIResponsesProviderModule(): Promise<
|
||||
LazyProviderModule<"azure-openai-responses", AzureOpenAIResponsesOptions, SimpleStreamOptions>
|
||||
> {
|
||||
azureOpenAIResponsesProviderModulePromise ||= import("./azure-openai-responses.js").then((module) => {
|
||||
const provider = module as AzureOpenAIResponsesProviderModule;
|
||||
return {
|
||||
stream: provider.streamAzureOpenAIResponses,
|
||||
streamSimple: provider.streamSimpleAzureOpenAIResponses,
|
||||
};
|
||||
});
|
||||
return azureOpenAIResponsesProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadGoogleProviderModule(): Promise<
|
||||
LazyProviderModule<"google-generative-ai", GoogleOptions, SimpleStreamOptions>
|
||||
> {
|
||||
googleProviderModulePromise ||= import("./google.js").then((module) => {
|
||||
const provider = module as GoogleProviderModule;
|
||||
return {
|
||||
stream: provider.streamGoogle,
|
||||
streamSimple: provider.streamSimpleGoogle,
|
||||
};
|
||||
});
|
||||
return googleProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadGoogleGeminiCliProviderModule(): Promise<
|
||||
LazyProviderModule<"google-gemini-cli", GoogleGeminiCliOptions, SimpleStreamOptions>
|
||||
> {
|
||||
googleGeminiCliProviderModulePromise ||= import("./google-gemini-cli.js").then((module) => {
|
||||
const provider = module as GoogleGeminiCliProviderModule;
|
||||
return {
|
||||
stream: provider.streamGoogleGeminiCli,
|
||||
streamSimple: provider.streamSimpleGoogleGeminiCli,
|
||||
};
|
||||
});
|
||||
return googleGeminiCliProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadGoogleVertexProviderModule(): Promise<
|
||||
LazyProviderModule<"google-vertex", GoogleVertexOptions, SimpleStreamOptions>
|
||||
> {
|
||||
googleVertexProviderModulePromise ||= import("./google-vertex.js").then((module) => {
|
||||
const provider = module as GoogleVertexProviderModule;
|
||||
return {
|
||||
stream: provider.streamGoogleVertex,
|
||||
streamSimple: provider.streamSimpleGoogleVertex,
|
||||
};
|
||||
});
|
||||
return googleVertexProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadMistralProviderModule(): Promise<
|
||||
LazyProviderModule<"mistral-conversations", MistralOptions, SimpleStreamOptions>
|
||||
> {
|
||||
mistralProviderModulePromise ||= import("./mistral.js").then((module) => {
|
||||
const provider = module as MistralProviderModule;
|
||||
return {
|
||||
stream: provider.streamMistral,
|
||||
streamSimple: provider.streamSimpleMistral,
|
||||
};
|
||||
});
|
||||
return mistralProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadOpenAICodexResponsesProviderModule(): Promise<
|
||||
LazyProviderModule<"openai-codex-responses", OpenAICodexResponsesOptions, SimpleStreamOptions>
|
||||
> {
|
||||
openAICodexResponsesProviderModulePromise ||= import("./openai-codex-responses.js").then((module) => {
|
||||
const provider = module as OpenAICodexResponsesProviderModule;
|
||||
return {
|
||||
stream: provider.streamOpenAICodexResponses,
|
||||
streamSimple: provider.streamSimpleOpenAICodexResponses,
|
||||
};
|
||||
});
|
||||
return openAICodexResponsesProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadOpenAICompletionsProviderModule(): Promise<
|
||||
LazyProviderModule<"openai-completions", OpenAICompletionsOptions, SimpleStreamOptions>
|
||||
> {
|
||||
openAICompletionsProviderModulePromise ||= import("./openai-completions.js").then((module) => {
|
||||
const provider = module as OpenAICompletionsProviderModule;
|
||||
return {
|
||||
stream: provider.streamOpenAICompletions,
|
||||
streamSimple: provider.streamSimpleOpenAICompletions,
|
||||
};
|
||||
});
|
||||
return openAICompletionsProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadOpenAIResponsesProviderModule(): Promise<
|
||||
LazyProviderModule<"openai-responses", OpenAIResponsesOptions, SimpleStreamOptions>
|
||||
> {
|
||||
openAIResponsesProviderModulePromise ||= import("./openai-responses.js").then((module) => {
|
||||
const provider = module as OpenAIResponsesProviderModule;
|
||||
return {
|
||||
stream: provider.streamOpenAIResponses,
|
||||
streamSimple: provider.streamSimpleOpenAIResponses,
|
||||
};
|
||||
});
|
||||
return openAIResponsesProviderModulePromise;
|
||||
}
|
||||
|
||||
function loadBedrockProviderModule(): Promise<
|
||||
LazyProviderModule<"bedrock-converse-stream", BedrockOptions, SimpleStreamOptions>
|
||||
> {
|
||||
if (bedrockProviderModuleOverride) {
|
||||
return Promise.resolve(bedrockProviderModuleOverride);
|
||||
}
|
||||
bedrockProviderModulePromise ||= importNodeOnlyProvider("./amazon-bedrock.js").then((module) => {
|
||||
const provider = module as BedrockProviderModule;
|
||||
return {
|
||||
stream: provider.streamBedrock,
|
||||
streamSimple: provider.streamSimpleBedrock,
|
||||
};
|
||||
});
|
||||
return bedrockProviderModulePromise;
|
||||
}
|
||||
|
||||
export const streamAnthropic = createLazyStream(loadAnthropicProviderModule);
|
||||
export const streamSimpleAnthropic = createLazySimpleStream(loadAnthropicProviderModule);
|
||||
export const streamAzureOpenAIResponses = createLazyStream(loadAzureOpenAIResponsesProviderModule);
|
||||
export const streamSimpleAzureOpenAIResponses = createLazySimpleStream(loadAzureOpenAIResponsesProviderModule);
|
||||
export const streamGoogle = createLazyStream(loadGoogleProviderModule);
|
||||
export const streamSimpleGoogle = createLazySimpleStream(loadGoogleProviderModule);
|
||||
export const streamGoogleGeminiCli = createLazyStream(loadGoogleGeminiCliProviderModule);
|
||||
export const streamSimpleGoogleGeminiCli = createLazySimpleStream(loadGoogleGeminiCliProviderModule);
|
||||
export const streamGoogleVertex = createLazyStream(loadGoogleVertexProviderModule);
|
||||
export const streamSimpleGoogleVertex = createLazySimpleStream(loadGoogleVertexProviderModule);
|
||||
export const streamMistral = createLazyStream(loadMistralProviderModule);
|
||||
export const streamSimpleMistral = createLazySimpleStream(loadMistralProviderModule);
|
||||
export const streamOpenAICodexResponses = createLazyStream(loadOpenAICodexResponsesProviderModule);
|
||||
export const streamSimpleOpenAICodexResponses = createLazySimpleStream(loadOpenAICodexResponsesProviderModule);
|
||||
export const streamOpenAICompletions = createLazyStream(loadOpenAICompletionsProviderModule);
|
||||
export const streamSimpleOpenAICompletions = createLazySimpleStream(loadOpenAICompletionsProviderModule);
|
||||
export const streamOpenAIResponses = createLazyStream(loadOpenAIResponsesProviderModule);
|
||||
export const streamSimpleOpenAIResponses = createLazySimpleStream(loadOpenAIResponsesProviderModule);
|
||||
const streamBedrockLazy = createLazyStream(loadBedrockProviderModule);
|
||||
const streamSimpleBedrockLazy = createLazySimpleStream(loadBedrockProviderModule);
|
||||
|
||||
export function registerBuiltInApiProviders(): void {
|
||||
registerApiProvider({
|
||||
api: "anthropic-messages",
|
||||
stream: streamAnthropic,
|
||||
streamSimple: streamSimpleAnthropic,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-completions",
|
||||
stream: streamOpenAICompletions,
|
||||
streamSimple: streamSimpleOpenAICompletions,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "mistral-conversations",
|
||||
stream: streamMistral,
|
||||
streamSimple: streamSimpleMistral,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-responses",
|
||||
stream: streamOpenAIResponses,
|
||||
streamSimple: streamSimpleOpenAIResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "azure-openai-responses",
|
||||
stream: streamAzureOpenAIResponses,
|
||||
streamSimple: streamSimpleAzureOpenAIResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-codex-responses",
|
||||
stream: streamOpenAICodexResponses,
|
||||
streamSimple: streamSimpleOpenAICodexResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-generative-ai",
|
||||
stream: streamGoogle,
|
||||
streamSimple: streamSimpleGoogle,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-gemini-cli",
|
||||
stream: streamGoogleGeminiCli,
|
||||
streamSimple: streamSimpleGoogleGeminiCli,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-vertex",
|
||||
stream: streamGoogleVertex,
|
||||
streamSimple: streamSimpleGoogleVertex,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "bedrock-converse-stream",
|
||||
stream: streamBedrockLazy,
|
||||
streamSimple: streamSimpleBedrockLazy,
|
||||
});
|
||||
}
|
||||
|
||||
export function resetApiProviders(): void {
|
||||
clearApiProviders();
|
||||
registerBuiltInApiProviders();
|
||||
}
|
||||
|
||||
registerBuiltInApiProviders();
|
||||
46
packages/ai/src/providers/simple-options.ts
Normal file
46
packages/ai/src/providers/simple-options.ts
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import type { Api, Model, SimpleStreamOptions, StreamOptions, ThinkingBudgets, ThinkingLevel } from "../types.js";
|
||||
|
||||
export function buildBaseOptions(model: Model<Api>, options?: SimpleStreamOptions, apiKey?: string): StreamOptions {
|
||||
return {
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens || Math.min(model.maxTokens, 32000),
|
||||
signal: options?.signal,
|
||||
apiKey: apiKey || options?.apiKey,
|
||||
cacheRetention: options?.cacheRetention,
|
||||
sessionId: options?.sessionId,
|
||||
headers: options?.headers,
|
||||
onPayload: options?.onPayload,
|
||||
maxRetryDelayMs: options?.maxRetryDelayMs,
|
||||
metadata: options?.metadata,
|
||||
};
|
||||
}
|
||||
|
||||
export function clampReasoning(effort: ThinkingLevel | undefined): Exclude<ThinkingLevel, "xhigh"> | undefined {
|
||||
return effort === "xhigh" ? "high" : effort;
|
||||
}
|
||||
|
||||
export function adjustMaxTokensForThinking(
|
||||
baseMaxTokens: number,
|
||||
modelMaxTokens: number,
|
||||
reasoningLevel: ThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): { maxTokens: number; thinkingBudget: number } {
|
||||
const defaultBudgets: ThinkingBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
};
|
||||
const budgets = { ...defaultBudgets, ...customBudgets };
|
||||
|
||||
const minOutputTokens = 1024;
|
||||
const level = clampReasoning(reasoningLevel)!;
|
||||
let thinkingBudget = budgets[level]!;
|
||||
const maxTokens = Math.min(baseMaxTokens + thinkingBudget, modelMaxTokens);
|
||||
|
||||
if (maxTokens <= thinkingBudget) {
|
||||
thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
|
||||
}
|
||||
|
||||
return { maxTokens, thinkingBudget };
|
||||
}
|
||||
172
packages/ai/src/providers/transform-messages.ts
Normal file
172
packages/ai/src/providers/transform-messages.ts
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
import type { Api, AssistantMessage, Message, Model, ToolCall, ToolResultMessage } from "../types.js";
|
||||
|
||||
/**
|
||||
* Normalize tool call ID for cross-provider compatibility.
|
||||
* OpenAI Responses API generates IDs that are 450+ chars with special characters like `|`.
|
||||
* Anthropic APIs require IDs matching ^[a-zA-Z0-9_-]+$ (max 64 chars).
|
||||
*/
|
||||
export function transformMessages<TApi extends Api>(
|
||||
messages: Message[],
|
||||
model: Model<TApi>,
|
||||
normalizeToolCallId?: (id: string, model: Model<TApi>, source: AssistantMessage) => string,
|
||||
): Message[] {
|
||||
// Build a map of original tool call IDs to normalized IDs
|
||||
const toolCallIdMap = new Map<string, string>();
|
||||
|
||||
// First pass: transform messages (thinking blocks, tool call ID normalization)
|
||||
const transformed = messages.map((msg) => {
|
||||
// User messages pass through unchanged
|
||||
if (msg.role === "user") {
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Handle toolResult messages - normalize toolCallId if we have a mapping
|
||||
if (msg.role === "toolResult") {
|
||||
const normalizedId = toolCallIdMap.get(msg.toolCallId);
|
||||
if (normalizedId && normalizedId !== msg.toolCallId) {
|
||||
return { ...msg, toolCallId: normalizedId };
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Assistant messages need transformation check
|
||||
if (msg.role === "assistant") {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
const isSameModel =
|
||||
assistantMsg.provider === model.provider &&
|
||||
assistantMsg.api === model.api &&
|
||||
assistantMsg.model === model.id;
|
||||
|
||||
const transformedContent = assistantMsg.content.flatMap((block) => {
|
||||
if (block.type === "thinking") {
|
||||
// Redacted thinking is opaque encrypted content, only valid for the same model.
|
||||
// Drop it for cross-model to avoid API errors.
|
||||
if (block.redacted) {
|
||||
return isSameModel ? block : [];
|
||||
}
|
||||
// For same model: keep thinking blocks with signatures (needed for replay)
|
||||
// even if the thinking text is empty (OpenAI encrypted reasoning)
|
||||
if (isSameModel && block.thinkingSignature) return block;
|
||||
// Skip empty thinking blocks, convert others to plain text
|
||||
if (!block.thinking || block.thinking.trim() === "") return [];
|
||||
if (isSameModel) return block;
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: block.thinking,
|
||||
};
|
||||
}
|
||||
|
||||
if (block.type === "text") {
|
||||
if (isSameModel) return block;
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: block.text,
|
||||
};
|
||||
}
|
||||
|
||||
if (block.type === "toolCall") {
|
||||
const toolCall = block as ToolCall;
|
||||
let normalizedToolCall: ToolCall = toolCall;
|
||||
|
||||
if (!isSameModel && toolCall.thoughtSignature) {
|
||||
normalizedToolCall = { ...toolCall };
|
||||
delete (normalizedToolCall as { thoughtSignature?: string }).thoughtSignature;
|
||||
}
|
||||
|
||||
if (!isSameModel && normalizeToolCallId) {
|
||||
const normalizedId = normalizeToolCallId(toolCall.id, model, assistantMsg);
|
||||
if (normalizedId !== toolCall.id) {
|
||||
toolCallIdMap.set(toolCall.id, normalizedId);
|
||||
normalizedToolCall = { ...normalizedToolCall, id: normalizedId };
|
||||
}
|
||||
}
|
||||
|
||||
return normalizedToolCall;
|
||||
}
|
||||
|
||||
return block;
|
||||
});
|
||||
|
||||
return {
|
||||
...assistantMsg,
|
||||
content: transformedContent,
|
||||
};
|
||||
}
|
||||
return msg;
|
||||
});
|
||||
|
||||
// Second pass: insert synthetic empty tool results for orphaned tool calls
|
||||
// This preserves thinking signatures and satisfies API requirements
|
||||
const result: Message[] = [];
|
||||
let pendingToolCalls: ToolCall[] = [];
|
||||
let existingToolResultIds = new Set<string>();
|
||||
|
||||
for (let i = 0; i < transformed.length; i++) {
|
||||
const msg = transformed[i];
|
||||
|
||||
if (msg.role === "assistant") {
|
||||
// If we have pending orphaned tool calls from a previous assistant, insert synthetic results now
|
||||
if (pendingToolCalls.length > 0) {
|
||||
for (const tc of pendingToolCalls) {
|
||||
if (!existingToolResultIds.has(tc.id)) {
|
||||
result.push({
|
||||
role: "toolResult",
|
||||
toolCallId: tc.id,
|
||||
toolName: tc.name,
|
||||
content: [{ type: "text", text: "No result provided" }],
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
} as ToolResultMessage);
|
||||
}
|
||||
}
|
||||
pendingToolCalls = [];
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
|
||||
// Skip errored/aborted assistant messages entirely.
|
||||
// These are incomplete turns that shouldn't be replayed:
|
||||
// - May have partial content (reasoning without message, incomplete tool calls)
|
||||
// - Replaying them can cause API errors (e.g., OpenAI "reasoning without following item")
|
||||
// - The model should retry from the last valid state
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
if (assistantMsg.stopReason === "error" || assistantMsg.stopReason === "aborted") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Track tool calls from this assistant message
|
||||
const toolCalls = assistantMsg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
pendingToolCalls = toolCalls;
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
|
||||
result.push(msg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
existingToolResultIds.add(msg.toolCallId);
|
||||
result.push(msg);
|
||||
} else if (msg.role === "user") {
|
||||
// User message interrupts tool flow - insert synthetic results for orphaned calls
|
||||
if (pendingToolCalls.length > 0) {
|
||||
for (const tc of pendingToolCalls) {
|
||||
if (!existingToolResultIds.has(tc.id)) {
|
||||
result.push({
|
||||
role: "toolResult",
|
||||
toolCallId: tc.id,
|
||||
toolName: tc.name,
|
||||
content: [{ type: "text", text: "No result provided" }],
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
} as ToolResultMessage);
|
||||
}
|
||||
}
|
||||
pendingToolCalls = [];
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
result.push(msg);
|
||||
} else {
|
||||
result.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
59
packages/ai/src/stream.ts
Normal file
59
packages/ai/src/stream.ts
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import "./providers/register-builtins.js";
|
||||
|
||||
import { getApiProvider } from "./api-registry.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
ProviderStreamOptions,
|
||||
SimpleStreamOptions,
|
||||
StreamOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export { getEnvApiKey } from "./env-api-keys.js";
|
||||
|
||||
function resolveApiProvider(api: Api) {
|
||||
const provider = getApiProvider(api);
|
||||
if (!provider) {
|
||||
throw new Error(`No API provider registered for api: ${api}`);
|
||||
}
|
||||
return provider;
|
||||
}
|
||||
|
||||
export function stream<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: ProviderStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const provider = resolveApiProvider(model.api);
|
||||
return provider.stream(model, context, options as StreamOptions);
|
||||
}
|
||||
|
||||
export async function complete<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: ProviderStreamOptions,
|
||||
): Promise<AssistantMessage> {
|
||||
const s = stream(model, context, options);
|
||||
return s.result();
|
||||
}
|
||||
|
||||
export function streamSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const provider = resolveApiProvider(model.api);
|
||||
return provider.streamSimple(model, context, options);
|
||||
}
|
||||
|
||||
export async function completeSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): Promise<AssistantMessage> {
|
||||
const s = streamSimple(model, context, options);
|
||||
return s.result();
|
||||
}
|
||||
337
packages/ai/src/types.ts
Normal file
337
packages/ai/src/types.ts
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
import type { AssistantMessageEventStream } from "./utils/event-stream.js";
|
||||
|
||||
export type { AssistantMessageEventStream } from "./utils/event-stream.js";
|
||||
|
||||
export type KnownApi =
|
||||
| "openai-completions"
|
||||
| "mistral-conversations"
|
||||
| "openai-responses"
|
||||
| "azure-openai-responses"
|
||||
| "openai-codex-responses"
|
||||
| "anthropic-messages"
|
||||
| "bedrock-converse-stream"
|
||||
| "google-generative-ai"
|
||||
| "google-gemini-cli"
|
||||
| "google-vertex";
|
||||
|
||||
export type Api = KnownApi | (string & {});
|
||||
|
||||
export type KnownProvider =
|
||||
| "amazon-bedrock"
|
||||
| "anthropic"
|
||||
| "google"
|
||||
| "google-gemini-cli"
|
||||
| "google-antigravity"
|
||||
| "google-vertex"
|
||||
| "openai"
|
||||
| "azure-openai-responses"
|
||||
| "openai-codex"
|
||||
| "github-copilot"
|
||||
| "xai"
|
||||
| "groq"
|
||||
| "cerebras"
|
||||
| "openrouter"
|
||||
| "vercel-ai-gateway"
|
||||
| "zai"
|
||||
| "mistral"
|
||||
| "minimax"
|
||||
| "minimax-cn"
|
||||
| "huggingface"
|
||||
| "opencode"
|
||||
| "opencode-go"
|
||||
| "kimi-coding";
|
||||
export type Provider = KnownProvider | string;
|
||||
|
||||
export type ThinkingLevel = "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
|
||||
/** Token budgets for each thinking level (token-based providers only) */
|
||||
export interface ThinkingBudgets {
|
||||
minimal?: number;
|
||||
low?: number;
|
||||
medium?: number;
|
||||
high?: number;
|
||||
}
|
||||
|
||||
// Base options all providers share
|
||||
export type CacheRetention = "none" | "short" | "long";
|
||||
|
||||
export type Transport = "sse" | "websocket" | "auto";
|
||||
|
||||
export interface StreamOptions {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
apiKey?: string;
|
||||
/**
|
||||
* Preferred transport for providers that support multiple transports.
|
||||
* Providers that do not support this option ignore it.
|
||||
*/
|
||||
transport?: Transport;
|
||||
/**
|
||||
* Prompt cache retention preference. Providers map this to their supported values.
|
||||
* Default: "short".
|
||||
*/
|
||||
cacheRetention?: CacheRetention;
|
||||
/**
|
||||
* Optional session identifier for providers that support session-based caching.
|
||||
* Providers can use this to enable prompt caching, request routing, or other
|
||||
* session-aware features. Ignored by providers that don't support it.
|
||||
*/
|
||||
sessionId?: string;
|
||||
/**
|
||||
* Optional callback for inspecting or replacing provider payloads before sending.
|
||||
* Return undefined to keep the payload unchanged.
|
||||
*/
|
||||
onPayload?: (payload: unknown, model: Model<Api>) => unknown | undefined | Promise<unknown | undefined>;
|
||||
/**
|
||||
* Optional custom HTTP headers to include in API requests.
|
||||
* Merged with provider defaults; can override default headers.
|
||||
* Not supported by all providers (e.g., AWS Bedrock uses SDK auth).
|
||||
*/
|
||||
headers?: Record<string, string>;
|
||||
/**
|
||||
* Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
|
||||
* If the server's requested delay exceeds this value, the request fails immediately
|
||||
* with an error containing the requested delay, allowing higher-level retry logic
|
||||
* to handle it with user visibility.
|
||||
* Default: 60000 (60 seconds). Set to 0 to disable the cap.
|
||||
*/
|
||||
maxRetryDelayMs?: number;
|
||||
/**
|
||||
* Optional metadata to include in API requests.
|
||||
* Providers extract the fields they understand and ignore the rest.
|
||||
* For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
|
||||
*/
|
||||
metadata?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;
|
||||
|
||||
// Unified options with reasoning passed to streamSimple() and completeSimple()
|
||||
export interface SimpleStreamOptions extends StreamOptions {
|
||||
reasoning?: ThinkingLevel;
|
||||
/** Custom token budgets for thinking levels (token-based providers only) */
|
||||
thinkingBudgets?: ThinkingBudgets;
|
||||
}
|
||||
|
||||
// Generic StreamFunction with typed options.
|
||||
//
|
||||
// Contract:
|
||||
// - Must return an AssistantMessageEventStream.
|
||||
// - Once invoked, request/model/runtime failures should be encoded in the
|
||||
// returned stream, not thrown.
|
||||
// - Error termination must produce an AssistantMessage with stopReason
|
||||
// "error" or "aborted" and errorMessage, emitted via the stream protocol.
|
||||
export type StreamFunction<TApi extends Api = Api, TOptions extends StreamOptions = StreamOptions> = (
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: TOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export interface TextSignatureV1 {
|
||||
v: 1;
|
||||
id: string;
|
||||
phase?: "commentary" | "final_answer";
|
||||
}
|
||||
|
||||
export interface TextContent {
|
||||
type: "text";
|
||||
text: string;
|
||||
textSignature?: string; // e.g., for OpenAI responses, message metadata (legacy id string or TextSignatureV1 JSON)
|
||||
}
|
||||
|
||||
export interface ThinkingContent {
|
||||
type: "thinking";
|
||||
thinking: string;
|
||||
thinkingSignature?: string; // e.g., for OpenAI responses, the reasoning item ID
|
||||
/** When true, the thinking content was redacted by safety filters. The opaque
|
||||
* encrypted payload is stored in `thinkingSignature` so it can be passed back
|
||||
* to the API for multi-turn continuity. */
|
||||
redacted?: boolean;
|
||||
}
|
||||
|
||||
export interface ImageContent {
|
||||
type: "image";
|
||||
data: string; // base64 encoded image data
|
||||
mimeType: string; // e.g., "image/jpeg", "image/png"
|
||||
}
|
||||
|
||||
export interface ToolCall {
|
||||
type: "toolCall";
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: Record<string, any>;
|
||||
thoughtSignature?: string; // Google-specific: opaque signature for reusing thought context
|
||||
}
|
||||
|
||||
export interface Usage {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
totalTokens: number;
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
total: number;
|
||||
};
|
||||
}
|
||||
|
||||
export type StopReason = "stop" | "length" | "toolUse" | "error" | "aborted";
|
||||
|
||||
export interface UserMessage {
|
||||
role: "user";
|
||||
content: string | (TextContent | ImageContent)[];
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
export interface AssistantMessage {
|
||||
role: "assistant";
|
||||
content: (TextContent | ThinkingContent | ToolCall)[];
|
||||
api: Api;
|
||||
provider: Provider;
|
||||
model: string;
|
||||
responseId?: string; // Provider-specific response/message identifier when the upstream API exposes one
|
||||
usage: Usage;
|
||||
stopReason: StopReason;
|
||||
errorMessage?: string;
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
export interface ToolResultMessage<TDetails = any> {
|
||||
role: "toolResult";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
content: (TextContent | ImageContent)[]; // Supports text and images
|
||||
details?: TDetails;
|
||||
isError: boolean;
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
export type Message = UserMessage | AssistantMessage | ToolResultMessage;
|
||||
|
||||
import type { TSchema } from "@sinclair/typebox";
|
||||
|
||||
export interface Tool<TParameters extends TSchema = TSchema> {
|
||||
name: string;
|
||||
description: string;
|
||||
parameters: TParameters;
|
||||
}
|
||||
|
||||
export interface Context {
|
||||
systemPrompt?: string;
|
||||
messages: Message[];
|
||||
tools?: Tool[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Event protocol for AssistantMessageEventStream.
|
||||
*
|
||||
* Streams should emit `start` before partial updates, then terminate with either:
|
||||
* - `done` carrying the final successful AssistantMessage, or
|
||||
* - `error` carrying the final AssistantMessage with stopReason "error" or "aborted"
|
||||
* and errorMessage.
|
||||
*/
|
||||
export type AssistantMessageEvent =
|
||||
| { type: "start"; partial: AssistantMessage }
|
||||
| { type: "text_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| { type: "text_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
|
||||
| { type: "text_end"; contentIndex: number; content: string; partial: AssistantMessage }
|
||||
| { type: "thinking_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| { type: "thinking_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
|
||||
| { type: "thinking_end"; contentIndex: number; content: string; partial: AssistantMessage }
|
||||
| { type: "toolcall_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| { type: "toolcall_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
|
||||
| { type: "toolcall_end"; contentIndex: number; toolCall: ToolCall; partial: AssistantMessage }
|
||||
| { type: "done"; reason: Extract<StopReason, "stop" | "length" | "toolUse">; message: AssistantMessage }
|
||||
| { type: "error"; reason: Extract<StopReason, "aborted" | "error">; error: AssistantMessage };
|
||||
|
||||
/**
|
||||
* Compatibility settings for OpenAI-compatible completions APIs.
|
||||
* Use this to override URL-based auto-detection for custom providers.
|
||||
*/
|
||||
export interface OpenAICompletionsCompat {
|
||||
/** Whether the provider supports the `store` field. Default: auto-detected from URL. */
|
||||
supportsStore?: boolean;
|
||||
/** Whether the provider supports the `developer` role (vs `system`). Default: auto-detected from URL. */
|
||||
supportsDeveloperRole?: boolean;
|
||||
/** Whether the provider supports `reasoning_effort`. Default: auto-detected from URL. */
|
||||
supportsReasoningEffort?: boolean;
|
||||
/** Optional mapping from pi-ai reasoning levels to provider/model-specific `reasoning_effort` values. */
|
||||
reasoningEffortMap?: Partial<Record<ThinkingLevel, string>>;
|
||||
/** Whether the provider supports `stream_options: { include_usage: true }` for token usage in streaming responses. Default: true. */
|
||||
supportsUsageInStreaming?: boolean;
|
||||
/** Which field to use for max tokens. Default: auto-detected from URL. */
|
||||
maxTokensField?: "max_completion_tokens" | "max_tokens";
|
||||
/** Whether tool results require the `name` field. Default: auto-detected from URL. */
|
||||
requiresToolResultName?: boolean;
|
||||
/** Whether a user message after tool results requires an assistant message in between. Default: auto-detected from URL. */
|
||||
requiresAssistantAfterToolResult?: boolean;
|
||||
/** Whether thinking blocks must be converted to text blocks with <thinking> delimiters. Default: auto-detected from URL. */
|
||||
requiresThinkingAsText?: boolean;
|
||||
/** Format for reasoning/thinking parameter. "openai" uses reasoning_effort, "openrouter" uses reasoning: { effort }, "zai" uses top-level enable_thinking: boolean, "qwen" uses top-level enable_thinking: boolean, and "qwen-chat-template" uses chat_template_kwargs.enable_thinking. Default: "openai". */
|
||||
thinkingFormat?: "openai" | "openrouter" | "zai" | "qwen" | "qwen-chat-template";
|
||||
/** OpenRouter-specific routing preferences. Only used when baseUrl points to OpenRouter. */
|
||||
openRouterRouting?: OpenRouterRouting;
|
||||
/** Vercel AI Gateway routing preferences. Only used when baseUrl points to Vercel AI Gateway. */
|
||||
vercelGatewayRouting?: VercelGatewayRouting;
|
||||
/** Whether the provider supports the `strict` field in tool definitions. Default: true. */
|
||||
supportsStrictMode?: boolean;
|
||||
}
|
||||
|
||||
/** Compatibility settings for OpenAI Responses APIs. */
|
||||
export interface OpenAIResponsesCompat {
|
||||
// Reserved for future use
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenRouter provider routing preferences.
|
||||
* Controls which upstream providers OpenRouter routes requests to.
|
||||
* @see https://openrouter.ai/docs/provider-routing
|
||||
*/
|
||||
export interface OpenRouterRouting {
|
||||
/** List of provider slugs to exclusively use for this request (e.g., ["amazon-bedrock", "anthropic"]). */
|
||||
only?: string[];
|
||||
/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
|
||||
order?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Vercel AI Gateway routing preferences.
|
||||
* Controls which upstream providers the gateway routes requests to.
|
||||
* @see https://vercel.com/docs/ai-gateway/models-and-providers/provider-options
|
||||
*/
|
||||
export interface VercelGatewayRouting {
|
||||
/** List of provider slugs to exclusively use for this request (e.g., ["bedrock", "anthropic"]). */
|
||||
only?: string[];
|
||||
/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
|
||||
order?: string[];
|
||||
}
|
||||
|
||||
// Model interface for the unified model system
|
||||
export interface Model<TApi extends Api> {
|
||||
id: string;
|
||||
name: string;
|
||||
api: TApi;
|
||||
provider: Provider;
|
||||
baseUrl: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: {
|
||||
input: number; // $/million tokens
|
||||
output: number; // $/million tokens
|
||||
cacheRead: number; // $/million tokens
|
||||
cacheWrite: number; // $/million tokens
|
||||
};
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
headers?: Record<string, string>;
|
||||
/** Compatibility overrides for OpenAI-compatible APIs. If not set, auto-detected from baseUrl. */
|
||||
compat?: TApi extends "openai-completions"
|
||||
? OpenAICompletionsCompat
|
||||
: TApi extends "openai-responses"
|
||||
? OpenAIResponsesCompat
|
||||
: never;
|
||||
}
|
||||
87
packages/ai/src/utils/event-stream.ts
Normal file
87
packages/ai/src/utils/event-stream.ts
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
import type { AssistantMessage, AssistantMessageEvent } from "../types.js";
|
||||
|
||||
// Generic event stream class for async iteration
|
||||
export class EventStream<T, R = T> implements AsyncIterable<T> {
|
||||
private queue: T[] = [];
|
||||
private waiting: ((value: IteratorResult<T>) => void)[] = [];
|
||||
private done = false;
|
||||
private finalResultPromise: Promise<R>;
|
||||
private resolveFinalResult!: (result: R) => void;
|
||||
|
||||
constructor(
|
||||
private isComplete: (event: T) => boolean,
|
||||
private extractResult: (event: T) => R,
|
||||
) {
|
||||
this.finalResultPromise = new Promise((resolve) => {
|
||||
this.resolveFinalResult = resolve;
|
||||
});
|
||||
}
|
||||
|
||||
push(event: T): void {
|
||||
if (this.done) return;
|
||||
|
||||
if (this.isComplete(event)) {
|
||||
this.done = true;
|
||||
this.resolveFinalResult(this.extractResult(event));
|
||||
}
|
||||
|
||||
// Deliver to waiting consumer or queue it
|
||||
const waiter = this.waiting.shift();
|
||||
if (waiter) {
|
||||
waiter({ value: event, done: false });
|
||||
} else {
|
||||
this.queue.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
end(result?: R): void {
|
||||
this.done = true;
|
||||
if (result !== undefined) {
|
||||
this.resolveFinalResult(result);
|
||||
}
|
||||
// Notify all waiting consumers that we're done
|
||||
while (this.waiting.length > 0) {
|
||||
const waiter = this.waiting.shift()!;
|
||||
waiter({ value: undefined as any, done: true });
|
||||
}
|
||||
}
|
||||
|
||||
async *[Symbol.asyncIterator](): AsyncIterator<T> {
|
||||
while (true) {
|
||||
if (this.queue.length > 0) {
|
||||
yield this.queue.shift()!;
|
||||
} else if (this.done) {
|
||||
return;
|
||||
} else {
|
||||
const result = await new Promise<IteratorResult<T>>((resolve) => this.waiting.push(resolve));
|
||||
if (result.done) return;
|
||||
yield result.value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result(): Promise<R> {
|
||||
return this.finalResultPromise;
|
||||
}
|
||||
}
|
||||
|
||||
export class AssistantMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") {
|
||||
return event.message;
|
||||
} else if (event.type === "error") {
|
||||
return event.error;
|
||||
}
|
||||
throw new Error("Unexpected event type for final result");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/** Factory function for AssistantMessageEventStream (for use in extensions) */
|
||||
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
|
||||
return new AssistantMessageEventStream();
|
||||
}
|
||||
13
packages/ai/src/utils/hash.ts
Normal file
13
packages/ai/src/utils/hash.ts
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
/** Fast deterministic hash to shorten long strings */
|
||||
export function shortHash(str: string): string {
|
||||
let h1 = 0xdeadbeef;
|
||||
let h2 = 0x41c6ce57;
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
const ch = str.charCodeAt(i);
|
||||
h1 = Math.imul(h1 ^ ch, 2654435761);
|
||||
h2 = Math.imul(h2 ^ ch, 1597334677);
|
||||
}
|
||||
h1 = Math.imul(h1 ^ (h1 >>> 16), 2246822507) ^ Math.imul(h2 ^ (h2 >>> 13), 3266489909);
|
||||
h2 = Math.imul(h2 ^ (h2 >>> 16), 2246822507) ^ Math.imul(h1 ^ (h1 >>> 13), 3266489909);
|
||||
return (h2 >>> 0).toString(36) + (h1 >>> 0).toString(36);
|
||||
}
|
||||
28
packages/ai/src/utils/json-parse.ts
Normal file
28
packages/ai/src/utils/json-parse.ts
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import { parse as partialParse } from "partial-json";
|
||||
|
||||
/**
|
||||
* Attempts to parse potentially incomplete JSON during streaming.
|
||||
* Always returns a valid object, even if the JSON is incomplete.
|
||||
*
|
||||
* @param partialJson The partial JSON string from streaming
|
||||
* @returns Parsed object or empty object if parsing fails
|
||||
*/
|
||||
export function parseStreamingJson<T = any>(partialJson: string | undefined): T {
|
||||
if (!partialJson || partialJson.trim() === "") {
|
||||
return {} as T;
|
||||
}
|
||||
|
||||
// Try standard parsing first (fastest for complete JSON)
|
||||
try {
|
||||
return JSON.parse(partialJson) as T;
|
||||
} catch {
|
||||
// Try partial-json for incomplete JSON
|
||||
try {
|
||||
const result = partialParse(partialJson);
|
||||
return (result ?? {}) as T;
|
||||
} catch {
|
||||
// If all parsing fails, return empty object
|
||||
return {} as T;
|
||||
}
|
||||
}
|
||||
}
|
||||
402
packages/ai/src/utils/oauth/anthropic.ts
Normal file
402
packages/ai/src/utils/oauth/anthropic.ts
Normal file
|
|
@ -0,0 +1,402 @@
|
|||
/**
|
||||
* Anthropic OAuth flow (Claude Pro/Max)
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback server.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import type { Server } from "node:http";
|
||||
import { oauthErrorHtml, oauthSuccessHtml } from "./oauth-page.js";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
type CallbackServerInfo = {
|
||||
server: Server;
|
||||
redirectUri: string;
|
||||
cancelWait: () => void;
|
||||
waitForCode: () => Promise<{ code: string; state: string } | null>;
|
||||
};
|
||||
|
||||
type NodeApis = {
|
||||
createServer: typeof import("node:http").createServer;
|
||||
};
|
||||
|
||||
let nodeApis: NodeApis | null = null;
|
||||
let nodeApisPromise: Promise<NodeApis> | null = null;
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
||||
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
|
||||
const TOKEN_URL = "https://platform.claude.com/v1/oauth/token";
|
||||
const CALLBACK_HOST = "127.0.0.1";
|
||||
const CALLBACK_PORT = 53692;
|
||||
const CALLBACK_PATH = "/callback";
|
||||
const REDIRECT_URI = `http://localhost:${CALLBACK_PORT}${CALLBACK_PATH}`;
|
||||
const SCOPES =
|
||||
"org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload";
|
||||
async function getNodeApis(): Promise<NodeApis> {
|
||||
if (nodeApis) return nodeApis;
|
||||
if (!nodeApisPromise) {
|
||||
if (typeof process === "undefined" || (!process.versions?.node && !process.versions?.bun)) {
|
||||
throw new Error("Anthropic OAuth is only available in Node.js environments");
|
||||
}
|
||||
nodeApisPromise = import("node:http").then((httpModule) => ({
|
||||
createServer: httpModule.createServer,
|
||||
}));
|
||||
}
|
||||
nodeApis = await nodeApisPromise;
|
||||
return nodeApis;
|
||||
}
|
||||
|
||||
function parseAuthorizationInput(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// not a URL
|
||||
}
|
||||
|
||||
if (value.includes("#")) {
|
||||
const [code, state] = value.split("#", 2);
|
||||
return { code, state };
|
||||
}
|
||||
|
||||
if (value.includes("code=")) {
|
||||
const params = new URLSearchParams(value);
|
||||
return {
|
||||
code: params.get("code") ?? undefined,
|
||||
state: params.get("state") ?? undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return { code: value };
|
||||
}
|
||||
|
||||
function formatErrorDetails(error: unknown): string {
|
||||
if (error instanceof Error) {
|
||||
const details: string[] = [`${error.name}: ${error.message}`];
|
||||
const errorWithCode = error as Error & { code?: string; errno?: number | string; cause?: unknown };
|
||||
if (errorWithCode.code) details.push(`code=${errorWithCode.code}`);
|
||||
if (typeof errorWithCode.errno !== "undefined") details.push(`errno=${String(errorWithCode.errno)}`);
|
||||
if (typeof error.cause !== "undefined") {
|
||||
details.push(`cause=${formatErrorDetails(error.cause)}`);
|
||||
}
|
||||
if (error.stack) {
|
||||
details.push(`stack=${error.stack}`);
|
||||
}
|
||||
return details.join("; ");
|
||||
}
|
||||
return String(error);
|
||||
}
|
||||
|
||||
async function startCallbackServer(expectedState: string): Promise<CallbackServerInfo> {
|
||||
const { createServer } = await getNodeApis();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
let settleWait: ((value: { code: string; state: string } | null) => void) | undefined;
|
||||
const waitForCodePromise = new Promise<{ code: string; state: string } | null>((resolveWait) => {
|
||||
let settled = false;
|
||||
settleWait = (value) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
resolveWait(value);
|
||||
};
|
||||
});
|
||||
|
||||
const server = createServer((req, res) => {
|
||||
try {
|
||||
const url = new URL(req.url || "", "http://localhost");
|
||||
if (url.pathname !== CALLBACK_PATH) {
|
||||
res.writeHead(404, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Callback route not found."));
|
||||
return;
|
||||
}
|
||||
|
||||
const code = url.searchParams.get("code");
|
||||
const state = url.searchParams.get("state");
|
||||
const error = url.searchParams.get("error");
|
||||
|
||||
if (error) {
|
||||
res.writeHead(400, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Anthropic authentication did not complete.", `Error: ${error}`));
|
||||
return;
|
||||
}
|
||||
|
||||
if (!code || !state) {
|
||||
res.writeHead(400, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Missing code or state parameter."));
|
||||
return;
|
||||
}
|
||||
|
||||
if (state !== expectedState) {
|
||||
res.writeHead(400, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("State mismatch."));
|
||||
return;
|
||||
}
|
||||
|
||||
res.writeHead(200, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthSuccessHtml("Anthropic authentication completed. You can close this window."));
|
||||
settleWait?.({ code, state });
|
||||
} catch {
|
||||
res.writeHead(500, { "Content-Type": "text/plain; charset=utf-8" });
|
||||
res.end("Internal error");
|
||||
}
|
||||
});
|
||||
|
||||
server.on("error", (err) => {
|
||||
reject(err);
|
||||
});
|
||||
|
||||
server.listen(CALLBACK_PORT, CALLBACK_HOST, () => {
|
||||
resolve({
|
||||
server,
|
||||
redirectUri: REDIRECT_URI,
|
||||
cancelWait: () => {
|
||||
settleWait?.(null);
|
||||
},
|
||||
waitForCode: () => waitForCodePromise,
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
async function postJson(url: string, body: Record<string, string | number>): Promise<string> {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "application/json",
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal: AbortSignal.timeout(30_000),
|
||||
});
|
||||
|
||||
const responseBody = await response.text();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP request failed. status=${response.status}; url=${url}; body=${responseBody}`);
|
||||
}
|
||||
|
||||
return responseBody;
|
||||
}
|
||||
|
||||
async function exchangeAuthorizationCode(
|
||||
code: string,
|
||||
state: string,
|
||||
verifier: string,
|
||||
redirectUri: string,
|
||||
): Promise<OAuthCredentials> {
|
||||
let responseBody: string;
|
||||
try {
|
||||
responseBody = await postJson(TOKEN_URL, {
|
||||
grant_type: "authorization_code",
|
||||
client_id: CLIENT_ID,
|
||||
code,
|
||||
state,
|
||||
redirect_uri: redirectUri,
|
||||
code_verifier: verifier,
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Token exchange request failed. url=${TOKEN_URL}; redirect_uri=${redirectUri}; response_type=authorization_code; details=${formatErrorDetails(error)}`,
|
||||
);
|
||||
}
|
||||
|
||||
let tokenData: { access_token: string; refresh_token: string; expires_in: number };
|
||||
try {
|
||||
tokenData = JSON.parse(responseBody) as { access_token: string; refresh_token: string; expires_in: number };
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Token exchange returned invalid JSON. url=${TOKEN_URL}; body=${responseBody}; details=${formatErrorDetails(error)}`,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with Anthropic OAuth (authorization code + PKCE)
|
||||
*/
|
||||
export async function loginAnthropic(options: {
|
||||
onAuth: (info: { url: string; instructions?: string }) => void;
|
||||
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
onManualCodeInput?: () => Promise<string>;
|
||||
}): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
const server = await startCallbackServer(verifier);
|
||||
|
||||
let code: string | undefined;
|
||||
let state: string | undefined;
|
||||
let redirectUriForExchange = REDIRECT_URI;
|
||||
|
||||
try {
|
||||
const authParams = new URLSearchParams({
|
||||
code: "true",
|
||||
client_id: CLIENT_ID,
|
||||
response_type: "code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
scope: SCOPES,
|
||||
code_challenge: challenge,
|
||||
code_challenge_method: "S256",
|
||||
state: verifier,
|
||||
});
|
||||
|
||||
options.onAuth({
|
||||
url: `${AUTHORIZE_URL}?${authParams.toString()}`,
|
||||
instructions:
|
||||
"Complete login in your browser. If the browser is on another machine, paste the final redirect URL here.",
|
||||
});
|
||||
|
||||
if (options.onManualCodeInput) {
|
||||
let manualInput: string | undefined;
|
||||
let manualError: Error | undefined;
|
||||
const manualPromise = options
|
||||
.onManualCodeInput()
|
||||
.then((input) => {
|
||||
manualInput = input;
|
||||
server.cancelWait();
|
||||
})
|
||||
.catch((err) => {
|
||||
manualError = err instanceof Error ? err : new Error(String(err));
|
||||
server.cancelWait();
|
||||
});
|
||||
|
||||
const result = await server.waitForCode();
|
||||
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
|
||||
if (result?.code) {
|
||||
code = result.code;
|
||||
state = result.state;
|
||||
redirectUriForExchange = REDIRECT_URI;
|
||||
} else if (manualInput) {
|
||||
const parsed = parseAuthorizationInput(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
state = parsed.state ?? verifier;
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
await manualPromise;
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
if (manualInput) {
|
||||
const parsed = parseAuthorizationInput(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
state = parsed.state ?? verifier;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const result = await server.waitForCode();
|
||||
if (result?.code) {
|
||||
code = result.code;
|
||||
state = result.state;
|
||||
redirectUriForExchange = REDIRECT_URI;
|
||||
}
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
const input = await options.onPrompt({
|
||||
message: "Paste the authorization code or full redirect URL:",
|
||||
placeholder: REDIRECT_URI,
|
||||
});
|
||||
const parsed = parseAuthorizationInput(input);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
state = parsed.state ?? verifier;
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
throw new Error("Missing authorization code");
|
||||
}
|
||||
|
||||
if (!state) {
|
||||
throw new Error("Missing OAuth state");
|
||||
}
|
||||
|
||||
options.onProgress?.("Exchanging authorization code for tokens...");
|
||||
return exchangeAuthorizationCode(code, state, verifier, redirectUriForExchange);
|
||||
} finally {
|
||||
server.server.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh Anthropic OAuth token
|
||||
*/
|
||||
export async function refreshAnthropicToken(refreshToken: string): Promise<OAuthCredentials> {
|
||||
let responseBody: string;
|
||||
try {
|
||||
responseBody = await postJson(TOKEN_URL, {
|
||||
grant_type: "refresh_token",
|
||||
client_id: CLIENT_ID,
|
||||
refresh_token: refreshToken,
|
||||
});
|
||||
} catch (error) {
|
||||
throw new Error(`Anthropic token refresh request failed. url=${TOKEN_URL}; details=${formatErrorDetails(error)}`);
|
||||
}
|
||||
|
||||
let data: { access_token: string; refresh_token: string; expires_in: number; scope?: string };
|
||||
try {
|
||||
data = JSON.parse(responseBody) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
scope?: string;
|
||||
};
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Anthropic token refresh returned invalid JSON. url=${TOKEN_URL}; body=${responseBody}; details=${formatErrorDetails(error)}`,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
refresh: data.refresh_token,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
};
|
||||
}
|
||||
|
||||
export const anthropicOAuthProvider: OAuthProviderInterface = {
|
||||
id: "anthropic",
|
||||
name: "Anthropic (Claude Pro/Max)",
|
||||
usesCallbackServer: true,
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginAnthropic({
|
||||
onAuth: callbacks.onAuth,
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
onManualCodeInput: callbacks.onManualCodeInput,
|
||||
});
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
return refreshAnthropicToken(credentials.refresh);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
return credentials.access;
|
||||
},
|
||||
};
|
||||
396
packages/ai/src/utils/oauth/github-copilot.ts
Normal file
396
packages/ai/src/utils/oauth/github-copilot.ts
Normal file
|
|
@ -0,0 +1,396 @@
|
|||
/**
|
||||
* GitHub Copilot OAuth flow
|
||||
*/
|
||||
|
||||
import { getModels } from "../../models.js";
|
||||
import type { Api, Model } from "../../types.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
type CopilotCredentials = OAuthCredentials & {
|
||||
enterpriseUrl?: string;
|
||||
};
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
|
||||
|
||||
const COPILOT_HEADERS = {
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
"Editor-Version": "vscode/1.107.0",
|
||||
"Editor-Plugin-Version": "copilot-chat/0.35.0",
|
||||
"Copilot-Integration-Id": "vscode-chat",
|
||||
} as const;
|
||||
|
||||
const INITIAL_POLL_INTERVAL_MULTIPLIER = 1.2;
|
||||
const SLOW_DOWN_POLL_INTERVAL_MULTIPLIER = 1.4;
|
||||
|
||||
type DeviceCodeResponse = {
|
||||
device_code: string;
|
||||
user_code: string;
|
||||
verification_uri: string;
|
||||
interval: number;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
type DeviceTokenSuccessResponse = {
|
||||
access_token: string;
|
||||
token_type?: string;
|
||||
scope?: string;
|
||||
};
|
||||
|
||||
type DeviceTokenErrorResponse = {
|
||||
error: string;
|
||||
error_description?: string;
|
||||
interval?: number;
|
||||
};
|
||||
|
||||
export function normalizeDomain(input: string): string | null {
|
||||
const trimmed = input.trim();
|
||||
if (!trimmed) return null;
|
||||
try {
|
||||
const url = trimmed.includes("://") ? new URL(trimmed) : new URL(`https://${trimmed}`);
|
||||
return url.hostname;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function getUrls(domain: string): {
|
||||
deviceCodeUrl: string;
|
||||
accessTokenUrl: string;
|
||||
copilotTokenUrl: string;
|
||||
} {
|
||||
return {
|
||||
deviceCodeUrl: `https://${domain}/login/device/code`,
|
||||
accessTokenUrl: `https://${domain}/login/oauth/access_token`,
|
||||
copilotTokenUrl: `https://api.${domain}/copilot_internal/v2/token`,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the proxy-ep from a Copilot token and convert to API base URL.
|
||||
* Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
|
||||
* Returns API URL like https://api.individual.githubcopilot.com
|
||||
*/
|
||||
function getBaseUrlFromToken(token: string): string | null {
|
||||
const match = token.match(/proxy-ep=([^;]+)/);
|
||||
if (!match) return null;
|
||||
const proxyHost = match[1];
|
||||
// Convert proxy.xxx to api.xxx
|
||||
const apiHost = proxyHost.replace(/^proxy\./, "api.");
|
||||
return `https://${apiHost}`;
|
||||
}
|
||||
|
||||
export function getGitHubCopilotBaseUrl(token?: string, enterpriseDomain?: string): string {
|
||||
// If we have a token, extract the base URL from proxy-ep
|
||||
if (token) {
|
||||
const urlFromToken = getBaseUrlFromToken(token);
|
||||
if (urlFromToken) return urlFromToken;
|
||||
}
|
||||
// Fallback for enterprise or if token parsing fails
|
||||
if (enterpriseDomain) return `https://copilot-api.${enterpriseDomain}`;
|
||||
return "https://api.individual.githubcopilot.com";
|
||||
}
|
||||
|
||||
async function fetchJson(url: string, init: RequestInit): Promise<unknown> {
|
||||
const response = await fetch(url, init);
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`${response.status} ${response.statusText}: ${text}`);
|
||||
}
|
||||
return response.json();
|
||||
}
|
||||
|
||||
async function startDeviceFlow(domain: string): Promise<DeviceCodeResponse> {
|
||||
const urls = getUrls(domain);
|
||||
const data = await fetchJson(urls.deviceCodeUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
scope: "read:user",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!data || typeof data !== "object") {
|
||||
throw new Error("Invalid device code response");
|
||||
}
|
||||
|
||||
const deviceCode = (data as Record<string, unknown>).device_code;
|
||||
const userCode = (data as Record<string, unknown>).user_code;
|
||||
const verificationUri = (data as Record<string, unknown>).verification_uri;
|
||||
const interval = (data as Record<string, unknown>).interval;
|
||||
const expiresIn = (data as Record<string, unknown>).expires_in;
|
||||
|
||||
if (
|
||||
typeof deviceCode !== "string" ||
|
||||
typeof userCode !== "string" ||
|
||||
typeof verificationUri !== "string" ||
|
||||
typeof interval !== "number" ||
|
||||
typeof expiresIn !== "number"
|
||||
) {
|
||||
throw new Error("Invalid device code response fields");
|
||||
}
|
||||
|
||||
return {
|
||||
device_code: deviceCode,
|
||||
user_code: userCode,
|
||||
verification_uri: verificationUri,
|
||||
interval,
|
||||
expires_in: expiresIn,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sleep that can be interrupted by an AbortSignal
|
||||
*/
|
||||
function abortableSleep(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Login cancelled"));
|
||||
return;
|
||||
}
|
||||
|
||||
const timeout = setTimeout(resolve, ms);
|
||||
|
||||
signal?.addEventListener(
|
||||
"abort",
|
||||
() => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Login cancelled"));
|
||||
},
|
||||
{ once: true },
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
async function pollForGitHubAccessToken(
|
||||
domain: string,
|
||||
deviceCode: string,
|
||||
intervalSeconds: number,
|
||||
expiresIn: number,
|
||||
signal?: AbortSignal,
|
||||
) {
|
||||
const urls = getUrls(domain);
|
||||
const deadline = Date.now() + expiresIn * 1000;
|
||||
let intervalMs = Math.max(1000, Math.floor(intervalSeconds * 1000));
|
||||
let intervalMultiplier = INITIAL_POLL_INTERVAL_MULTIPLIER;
|
||||
let slowDownResponses = 0;
|
||||
|
||||
while (Date.now() < deadline) {
|
||||
if (signal?.aborted) {
|
||||
throw new Error("Login cancelled");
|
||||
}
|
||||
|
||||
const remainingMs = deadline - Date.now();
|
||||
const waitMs = Math.min(Math.ceil(intervalMs * intervalMultiplier), remainingMs);
|
||||
await abortableSleep(waitMs, signal);
|
||||
|
||||
const raw = await fetchJson(urls.accessTokenUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
device_code: deviceCode,
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}),
|
||||
});
|
||||
|
||||
if (raw && typeof raw === "object" && typeof (raw as DeviceTokenSuccessResponse).access_token === "string") {
|
||||
return (raw as DeviceTokenSuccessResponse).access_token;
|
||||
}
|
||||
|
||||
if (raw && typeof raw === "object" && typeof (raw as DeviceTokenErrorResponse).error === "string") {
|
||||
const { error, error_description: description, interval } = raw as DeviceTokenErrorResponse;
|
||||
if (error === "authorization_pending") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (error === "slow_down") {
|
||||
slowDownResponses += 1;
|
||||
intervalMs =
|
||||
typeof interval === "number" && interval > 0 ? interval * 1000 : Math.max(1000, intervalMs + 5000);
|
||||
intervalMultiplier = SLOW_DOWN_POLL_INTERVAL_MULTIPLIER;
|
||||
continue;
|
||||
}
|
||||
|
||||
const descriptionSuffix = description ? `: ${description}` : "";
|
||||
throw new Error(`Device flow failed: ${error}${descriptionSuffix}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (slowDownResponses > 0) {
|
||||
throw new Error(
|
||||
"Device flow timed out after one or more slow_down responses. This is often caused by clock drift in WSL or VM environments. Please sync or restart the VM clock and try again.",
|
||||
);
|
||||
}
|
||||
|
||||
throw new Error("Device flow timed out");
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh GitHub Copilot token
|
||||
*/
|
||||
export async function refreshGitHubCopilotToken(
|
||||
refreshToken: string,
|
||||
enterpriseDomain?: string,
|
||||
): Promise<OAuthCredentials> {
|
||||
const domain = enterpriseDomain || "github.com";
|
||||
const urls = getUrls(domain);
|
||||
|
||||
const raw = await fetchJson(urls.copilotTokenUrl, {
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
Authorization: `Bearer ${refreshToken}`,
|
||||
...COPILOT_HEADERS,
|
||||
},
|
||||
});
|
||||
|
||||
if (!raw || typeof raw !== "object") {
|
||||
throw new Error("Invalid Copilot token response");
|
||||
}
|
||||
|
||||
const token = (raw as Record<string, unknown>).token;
|
||||
const expiresAt = (raw as Record<string, unknown>).expires_at;
|
||||
|
||||
if (typeof token !== "string" || typeof expiresAt !== "number") {
|
||||
throw new Error("Invalid Copilot token response fields");
|
||||
}
|
||||
|
||||
return {
|
||||
refresh: refreshToken,
|
||||
access: token,
|
||||
expires: expiresAt * 1000 - 5 * 60 * 1000,
|
||||
enterpriseUrl: enterpriseDomain,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a model for the user's GitHub Copilot account.
|
||||
* This is required for some models (like Claude, Grok) before they can be used.
|
||||
*/
|
||||
async function enableGitHubCopilotModel(token: string, modelId: string, enterpriseDomain?: string): Promise<boolean> {
|
||||
const baseUrl = getGitHubCopilotBaseUrl(token, enterpriseDomain);
|
||||
const url = `${baseUrl}/models/${modelId}/policy`;
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
...COPILOT_HEADERS,
|
||||
"openai-intent": "chat-policy",
|
||||
"x-interaction-type": "chat-policy",
|
||||
},
|
||||
body: JSON.stringify({ state: "enabled" }),
|
||||
});
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable all known GitHub Copilot models that may require policy acceptance.
|
||||
* Called after successful login to ensure all models are available.
|
||||
*/
|
||||
async function enableAllGitHubCopilotModels(
|
||||
token: string,
|
||||
enterpriseDomain?: string,
|
||||
onProgress?: (model: string, success: boolean) => void,
|
||||
): Promise<void> {
|
||||
const models = getModels("github-copilot");
|
||||
await Promise.all(
|
||||
models.map(async (model) => {
|
||||
const success = await enableGitHubCopilotModel(token, model.id, enterpriseDomain);
|
||||
onProgress?.(model.id, success);
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with GitHub Copilot OAuth (device code flow)
|
||||
*
|
||||
* @param options.onAuth - Callback with URL and optional instructions (user code)
|
||||
* @param options.onPrompt - Callback to prompt user for input
|
||||
* @param options.onProgress - Optional progress callback
|
||||
* @param options.signal - Optional AbortSignal for cancellation
|
||||
*/
|
||||
export async function loginGitHubCopilot(options: {
|
||||
onAuth: (url: string, instructions?: string) => void;
|
||||
onPrompt: (prompt: { message: string; placeholder?: string; allowEmpty?: boolean }) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
signal?: AbortSignal;
|
||||
}): Promise<OAuthCredentials> {
|
||||
const input = await options.onPrompt({
|
||||
message: "GitHub Enterprise URL/domain (blank for github.com)",
|
||||
placeholder: "company.ghe.com",
|
||||
allowEmpty: true,
|
||||
});
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Login cancelled");
|
||||
}
|
||||
|
||||
const trimmed = input.trim();
|
||||
const enterpriseDomain = normalizeDomain(input);
|
||||
if (trimmed && !enterpriseDomain) {
|
||||
throw new Error("Invalid GitHub Enterprise URL/domain");
|
||||
}
|
||||
const domain = enterpriseDomain || "github.com";
|
||||
|
||||
const device = await startDeviceFlow(domain);
|
||||
options.onAuth(device.verification_uri, `Enter code: ${device.user_code}`);
|
||||
|
||||
const githubAccessToken = await pollForGitHubAccessToken(
|
||||
domain,
|
||||
device.device_code,
|
||||
device.interval,
|
||||
device.expires_in,
|
||||
options.signal,
|
||||
);
|
||||
const credentials = await refreshGitHubCopilotToken(githubAccessToken, enterpriseDomain ?? undefined);
|
||||
|
||||
// Enable all models after successful login
|
||||
options.onProgress?.("Enabling models...");
|
||||
await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined);
|
||||
return credentials;
|
||||
}
|
||||
|
||||
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
|
||||
id: "github-copilot",
|
||||
name: "GitHub Copilot",
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginGitHubCopilot({
|
||||
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
signal: callbacks.signal,
|
||||
});
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const creds = credentials as CopilotCredentials;
|
||||
return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
return credentials.access;
|
||||
},
|
||||
|
||||
modifyModels(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[] {
|
||||
const creds = credentials as CopilotCredentials;
|
||||
const domain = creds.enterpriseUrl ? (normalizeDomain(creds.enterpriseUrl) ?? undefined) : undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
|
||||
return models.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
|
||||
},
|
||||
};
|
||||
453
packages/ai/src/utils/oauth/google-antigravity.ts
Normal file
453
packages/ai/src/utils/oauth/google-antigravity.ts
Normal file
|
|
@ -0,0 +1,453 @@
|
|||
/**
|
||||
* Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud)
|
||||
* Uses different OAuth credentials than google-gemini-cli for access to additional models.
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import type { Server } from "node:http";
|
||||
import { oauthErrorHtml, oauthSuccessHtml } from "./oauth-page.js";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
type AntigravityCredentials = OAuthCredentials & {
|
||||
projectId: string;
|
||||
};
|
||||
|
||||
let _createServer: typeof import("node:http").createServer | null = null;
|
||||
let _httpImportPromise: Promise<void> | null = null;
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
_httpImportPromise = import("node:http").then((m) => {
|
||||
_createServer = m.createServer;
|
||||
});
|
||||
}
|
||||
|
||||
// Antigravity OAuth credentials (different from Gemini CLI)
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode(
|
||||
"MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
|
||||
);
|
||||
const CLIENT_SECRET = decode("R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=");
|
||||
const REDIRECT_URI = "http://localhost:51121/oauth-callback";
|
||||
|
||||
// Antigravity requires additional scopes
|
||||
const SCOPES = [
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
];
|
||||
|
||||
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
|
||||
const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
||||
|
||||
// Fallback project ID when discovery fails
|
||||
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
|
||||
|
||||
type CallbackServerInfo = {
|
||||
server: Server;
|
||||
cancelWait: () => void;
|
||||
waitForCode: () => Promise<{ code: string; state: string } | null>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to receive the OAuth callback
|
||||
*/
|
||||
async function getNodeCreateServer(): Promise<typeof import("node:http").createServer> {
|
||||
if (_createServer) return _createServer;
|
||||
if (_httpImportPromise) {
|
||||
await _httpImportPromise;
|
||||
}
|
||||
if (_createServer) return _createServer;
|
||||
throw new Error("Antigravity OAuth is only available in Node.js environments");
|
||||
}
|
||||
|
||||
async function startCallbackServer(): Promise<CallbackServerInfo> {
|
||||
const createServer = await getNodeCreateServer();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
let settleWait: ((value: { code: string; state: string } | null) => void) | undefined;
|
||||
const waitForCodePromise = new Promise<{ code: string; state: string } | null>((resolveWait) => {
|
||||
let settled = false;
|
||||
settleWait = (value) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
resolveWait(value);
|
||||
};
|
||||
});
|
||||
|
||||
const server = createServer((req, res) => {
|
||||
const url = new URL(req.url || "", `http://localhost:51121`);
|
||||
|
||||
if (url.pathname === "/oauth-callback") {
|
||||
const code = url.searchParams.get("code");
|
||||
const state = url.searchParams.get("state");
|
||||
const error = url.searchParams.get("error");
|
||||
|
||||
if (error) {
|
||||
res.writeHead(400, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Google authentication did not complete.", `Error: ${error}`));
|
||||
return;
|
||||
}
|
||||
|
||||
if (code && state) {
|
||||
res.writeHead(200, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthSuccessHtml("Google authentication completed. You can close this window."));
|
||||
settleWait?.({ code, state });
|
||||
} else {
|
||||
res.writeHead(400, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Missing code or state parameter."));
|
||||
}
|
||||
} else {
|
||||
res.writeHead(404, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Callback route not found."));
|
||||
}
|
||||
});
|
||||
|
||||
server.on("error", (err) => {
|
||||
reject(err);
|
||||
});
|
||||
|
||||
server.listen(51121, "127.0.0.1", () => {
|
||||
resolve({
|
||||
server,
|
||||
cancelWait: () => {
|
||||
settleWait?.(null);
|
||||
},
|
||||
waitForCode: () => waitForCodePromise,
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse redirect URL to extract code and state
|
||||
*/
|
||||
function parseRedirectUrl(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// Not a URL, return empty
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
interface LoadCodeAssistPayload {
|
||||
cloudaicompanionProject?: string | { id?: string };
|
||||
currentTier?: { id?: string };
|
||||
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover or provision a project for the user
|
||||
*/
|
||||
async function discoverProject(accessToken: string, onProgress?: (message: string) => void): Promise<string> {
|
||||
const headers = {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "google-api-nodejs-client/9.15.1",
|
||||
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
"Client-Metadata": JSON.stringify({
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
}),
|
||||
};
|
||||
|
||||
// Try endpoints in order: prod first, then sandbox
|
||||
const endpoints = ["https://cloudcode-pa.googleapis.com", "https://daily-cloudcode-pa.sandbox.googleapis.com"];
|
||||
|
||||
onProgress?.("Checking for existing project...");
|
||||
|
||||
for (const endpoint of endpoints) {
|
||||
try {
|
||||
const loadResponse = await fetch(`${endpoint}/v1internal:loadCodeAssist`, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
metadata: {
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
if (loadResponse.ok) {
|
||||
const data = (await loadResponse.json()) as LoadCodeAssistPayload;
|
||||
|
||||
// Handle both string and object formats
|
||||
if (typeof data.cloudaicompanionProject === "string" && data.cloudaicompanionProject) {
|
||||
return data.cloudaicompanionProject;
|
||||
}
|
||||
if (
|
||||
data.cloudaicompanionProject &&
|
||||
typeof data.cloudaicompanionProject === "object" &&
|
||||
data.cloudaicompanionProject.id
|
||||
) {
|
||||
return data.cloudaicompanionProject.id;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Try next endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// Use fallback project ID
|
||||
onProgress?.("Using default project...");
|
||||
return DEFAULT_PROJECT_ID;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get user email from the access token
|
||||
*/
|
||||
async function getUserEmail(accessToken: string): Promise<string | undefined> {
|
||||
try {
|
||||
const response = await fetch("https://www.googleapis.com/oauth2/v1/userinfo?alt=json", {
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = (await response.json()) as { email?: string };
|
||||
return data.email;
|
||||
}
|
||||
} catch {
|
||||
// Ignore errors, email is optional
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh Antigravity token
|
||||
*/
|
||||
export async function refreshAntigravityToken(refreshToken: string, projectId: string): Promise<OAuthCredentials> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
refresh_token: refreshToken,
|
||||
grant_type: "refresh_token",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`Antigravity token refresh failed: ${error}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
return {
|
||||
refresh: data.refresh_token || refreshToken,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
projectId,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with Antigravity OAuth
|
||||
*
|
||||
* @param onAuth - Callback with URL and optional instructions
|
||||
* @param onProgress - Optional progress callback
|
||||
* @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
|
||||
* Races with browser callback - whichever completes first wins.
|
||||
*/
|
||||
export async function loginAntigravity(
|
||||
onAuth: (info: { url: string; instructions?: string }) => void,
|
||||
onProgress?: (message: string) => void,
|
||||
onManualCodeInput?: () => Promise<string>,
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
// Start local server for callback
|
||||
onProgress?.("Starting local server for OAuth callback...");
|
||||
const server = await startCallbackServer();
|
||||
|
||||
let code: string | undefined;
|
||||
|
||||
try {
|
||||
// Build authorization URL
|
||||
const authParams = new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
response_type: "code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
scope: SCOPES.join(" "),
|
||||
code_challenge: challenge,
|
||||
code_challenge_method: "S256",
|
||||
state: verifier,
|
||||
access_type: "offline",
|
||||
prompt: "consent",
|
||||
});
|
||||
|
||||
const authUrl = `${AUTH_URL}?${authParams.toString()}`;
|
||||
|
||||
// Notify caller with URL to open
|
||||
onAuth({
|
||||
url: authUrl,
|
||||
instructions: "Complete the sign-in in your browser.",
|
||||
});
|
||||
|
||||
// Wait for the callback, racing with manual input if provided
|
||||
onProgress?.("Waiting for OAuth callback...");
|
||||
|
||||
if (onManualCodeInput) {
|
||||
// Race between browser callback and manual input
|
||||
let manualInput: string | undefined;
|
||||
let manualError: Error | undefined;
|
||||
const manualPromise = onManualCodeInput()
|
||||
.then((input) => {
|
||||
manualInput = input;
|
||||
server.cancelWait();
|
||||
})
|
||||
.catch((err) => {
|
||||
manualError = err instanceof Error ? err : new Error(String(err));
|
||||
server.cancelWait();
|
||||
});
|
||||
|
||||
const result = await server.waitForCode();
|
||||
|
||||
// If manual input was cancelled, throw that error
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
|
||||
if (result?.code) {
|
||||
// Browser callback won - verify state
|
||||
if (result.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = result.code;
|
||||
} else if (manualInput) {
|
||||
// Manual input won
|
||||
const parsed = parseRedirectUrl(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
|
||||
// If still no code, wait for manual promise and try that
|
||||
if (!code) {
|
||||
await manualPromise;
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
if (manualInput) {
|
||||
const parsed = parseRedirectUrl(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Original flow: just wait for callback
|
||||
const result = await server.waitForCode();
|
||||
if (result?.code) {
|
||||
if (result.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = result.code;
|
||||
}
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
throw new Error("No authorization code received");
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
onProgress?.("Exchanging authorization code for tokens...");
|
||||
const tokenResponse = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
code,
|
||||
grant_type: "authorization_code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
code_verifier: verifier,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!tokenResponse.ok) {
|
||||
const error = await tokenResponse.text();
|
||||
throw new Error(`Token exchange failed: ${error}`);
|
||||
}
|
||||
|
||||
const tokenData = (await tokenResponse.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
if (!tokenData.refresh_token) {
|
||||
throw new Error("No refresh token received. Please try again.");
|
||||
}
|
||||
|
||||
// Get user email
|
||||
onProgress?.("Getting user info...");
|
||||
const email = await getUserEmail(tokenData.access_token);
|
||||
|
||||
// Discover project
|
||||
const projectId = await discoverProject(tokenData.access_token, onProgress);
|
||||
|
||||
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
|
||||
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
|
||||
|
||||
const credentials: OAuthCredentials = {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: expiresAt,
|
||||
projectId,
|
||||
email,
|
||||
};
|
||||
|
||||
return credentials;
|
||||
} finally {
|
||||
server.server.close();
|
||||
}
|
||||
}
|
||||
|
||||
export const antigravityOAuthProvider: OAuthProviderInterface = {
|
||||
id: "google-antigravity",
|
||||
name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
|
||||
usesCallbackServer: true,
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const creds = credentials as AntigravityCredentials;
|
||||
if (!creds.projectId) {
|
||||
throw new Error("Antigravity credentials missing projectId");
|
||||
}
|
||||
return refreshAntigravityToken(creds.refresh, creds.projectId);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
const creds = credentials as AntigravityCredentials;
|
||||
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
|
||||
},
|
||||
};
|
||||
595
packages/ai/src/utils/oauth/google-gemini-cli.ts
Normal file
595
packages/ai/src/utils/oauth/google-gemini-cli.ts
Normal file
|
|
@ -0,0 +1,595 @@
|
|||
/**
|
||||
* Gemini CLI OAuth flow (Google Cloud Code Assist)
|
||||
* Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*)
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import type { Server } from "node:http";
|
||||
import { oauthErrorHtml, oauthSuccessHtml } from "./oauth-page.js";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
type GeminiCredentials = OAuthCredentials & {
|
||||
projectId: string;
|
||||
};
|
||||
|
||||
let _createServer: typeof import("node:http").createServer | null = null;
|
||||
let _httpImportPromise: Promise<void> | null = null;
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
_httpImportPromise = import("node:http").then((m) => {
|
||||
_createServer = m.createServer;
|
||||
});
|
||||
}
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode(
|
||||
"NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t",
|
||||
);
|
||||
const CLIENT_SECRET = decode("R09DU1BYLTR1SGdNUG0tMW83U2stZ2VWNkN1NWNsWEZzeGw=");
|
||||
const REDIRECT_URI = "http://localhost:8085/oauth2callback";
|
||||
const SCOPES = [
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
];
|
||||
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
|
||||
const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
||||
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
|
||||
|
||||
type CallbackServerInfo = {
|
||||
server: Server;
|
||||
cancelWait: () => void;
|
||||
waitForCode: () => Promise<{ code: string; state: string } | null>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to receive the OAuth callback
|
||||
*/
|
||||
async function getNodeCreateServer(): Promise<typeof import("node:http").createServer> {
|
||||
if (_createServer) return _createServer;
|
||||
if (_httpImportPromise) {
|
||||
await _httpImportPromise;
|
||||
}
|
||||
if (_createServer) return _createServer;
|
||||
throw new Error("Gemini CLI OAuth is only available in Node.js environments");
|
||||
}
|
||||
|
||||
async function startCallbackServer(): Promise<CallbackServerInfo> {
|
||||
const createServer = await getNodeCreateServer();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
let settleWait: ((value: { code: string; state: string } | null) => void) | undefined;
|
||||
const waitForCodePromise = new Promise<{ code: string; state: string } | null>((resolveWait) => {
|
||||
let settled = false;
|
||||
settleWait = (value) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
resolveWait(value);
|
||||
};
|
||||
});
|
||||
|
||||
const server = createServer((req, res) => {
|
||||
const url = new URL(req.url || "", `http://localhost:8085`);
|
||||
|
||||
if (url.pathname === "/oauth2callback") {
|
||||
const code = url.searchParams.get("code");
|
||||
const state = url.searchParams.get("state");
|
||||
const error = url.searchParams.get("error");
|
||||
|
||||
if (error) {
|
||||
res.writeHead(400, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Google authentication did not complete.", `Error: ${error}`));
|
||||
return;
|
||||
}
|
||||
|
||||
if (code && state) {
|
||||
res.writeHead(200, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthSuccessHtml("Google authentication completed. You can close this window."));
|
||||
settleWait?.({ code, state });
|
||||
} else {
|
||||
res.writeHead(400, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Missing code or state parameter."));
|
||||
}
|
||||
} else {
|
||||
res.writeHead(404, { "Content-Type": "text/html; charset=utf-8" });
|
||||
res.end(oauthErrorHtml("Callback route not found."));
|
||||
}
|
||||
});
|
||||
|
||||
server.on("error", (err) => {
|
||||
reject(err);
|
||||
});
|
||||
|
||||
server.listen(8085, "127.0.0.1", () => {
|
||||
resolve({
|
||||
server,
|
||||
cancelWait: () => {
|
||||
settleWait?.(null);
|
||||
},
|
||||
waitForCode: () => waitForCodePromise,
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse redirect URL to extract code and state
|
||||
*/
|
||||
function parseRedirectUrl(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// Not a URL, return empty
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
interface LoadCodeAssistPayload {
|
||||
cloudaicompanionProject?: string;
|
||||
currentTier?: { id?: string };
|
||||
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Long-running operation response from onboardUser
|
||||
*/
|
||||
interface LongRunningOperationResponse {
|
||||
name?: string;
|
||||
done?: boolean;
|
||||
response?: {
|
||||
cloudaicompanionProject?: { id?: string };
|
||||
};
|
||||
}
|
||||
|
||||
// Tier IDs as used by the Cloud Code API
|
||||
const TIER_FREE = "free-tier";
|
||||
const TIER_LEGACY = "legacy-tier";
|
||||
const TIER_STANDARD = "standard-tier";
|
||||
|
||||
interface GoogleRpcErrorResponse {
|
||||
error?: {
|
||||
details?: Array<{ reason?: string }>;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait helper for onboarding retries
|
||||
*/
|
||||
function wait(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get default tier from allowed tiers
|
||||
*/
|
||||
function getDefaultTier(allowedTiers?: Array<{ id?: string; isDefault?: boolean }>): { id?: string } {
|
||||
if (!allowedTiers || allowedTiers.length === 0) return { id: TIER_LEGACY };
|
||||
const defaultTier = allowedTiers.find((t) => t.isDefault);
|
||||
return defaultTier ?? { id: TIER_LEGACY };
|
||||
}
|
||||
|
||||
function isVpcScAffectedUser(payload: unknown): boolean {
|
||||
if (!payload || typeof payload !== "object") return false;
|
||||
if (!("error" in payload)) return false;
|
||||
const error = (payload as GoogleRpcErrorResponse).error;
|
||||
if (!error?.details || !Array.isArray(error.details)) return false;
|
||||
return error.details.some((detail) => detail.reason === "SECURITY_POLICY_VIOLATED");
|
||||
}
|
||||
|
||||
/**
|
||||
* Poll a long-running operation until completion
|
||||
*/
|
||||
async function pollOperation(
|
||||
operationName: string,
|
||||
headers: Record<string, string>,
|
||||
onProgress?: (message: string) => void,
|
||||
): Promise<LongRunningOperationResponse> {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
if (attempt > 0) {
|
||||
onProgress?.(`Waiting for project provisioning (attempt ${attempt + 1})...`);
|
||||
await wait(5000);
|
||||
}
|
||||
|
||||
const response = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal/${operationName}`, {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to poll operation: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as LongRunningOperationResponse;
|
||||
if (data.done) {
|
||||
return data;
|
||||
}
|
||||
|
||||
attempt += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover or provision a Google Cloud project for the user
|
||||
*/
|
||||
async function discoverProject(accessToken: string, onProgress?: (message: string) => void): Promise<string> {
|
||||
// Check for user-provided project ID via environment variable
|
||||
const envProjectId = process.env.GOOGLE_CLOUD_PROJECT || process.env.GOOGLE_CLOUD_PROJECT_ID;
|
||||
|
||||
const headers = {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "google-api-nodejs-client/9.15.1",
|
||||
"X-Goog-Api-Client": "gl-node/22.17.0",
|
||||
};
|
||||
|
||||
// Try to load existing project via loadCodeAssist
|
||||
onProgress?.("Checking for existing Cloud Code Assist project...");
|
||||
const loadResponse = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist`, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
cloudaicompanionProject: envProjectId,
|
||||
metadata: {
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
duetProject: envProjectId,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
let data: LoadCodeAssistPayload;
|
||||
|
||||
if (!loadResponse.ok) {
|
||||
let errorPayload: unknown;
|
||||
try {
|
||||
errorPayload = await loadResponse.clone().json();
|
||||
} catch {
|
||||
errorPayload = undefined;
|
||||
}
|
||||
|
||||
if (isVpcScAffectedUser(errorPayload)) {
|
||||
data = { currentTier: { id: TIER_STANDARD } };
|
||||
} else {
|
||||
const errorText = await loadResponse.text();
|
||||
throw new Error(`loadCodeAssist failed: ${loadResponse.status} ${loadResponse.statusText}: ${errorText}`);
|
||||
}
|
||||
} else {
|
||||
data = (await loadResponse.json()) as LoadCodeAssistPayload;
|
||||
}
|
||||
|
||||
// If user already has a current tier and project, use it
|
||||
if (data.currentTier) {
|
||||
if (data.cloudaicompanionProject) {
|
||||
return data.cloudaicompanionProject;
|
||||
}
|
||||
// User has a tier but no managed project - they need to provide one via env var
|
||||
if (envProjectId) {
|
||||
return envProjectId;
|
||||
}
|
||||
throw new Error(
|
||||
"This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
|
||||
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
|
||||
);
|
||||
}
|
||||
|
||||
// User needs to be onboarded - get the default tier
|
||||
const tier = getDefaultTier(data.allowedTiers);
|
||||
const tierId = tier?.id ?? TIER_FREE;
|
||||
|
||||
if (tierId !== TIER_FREE && !envProjectId) {
|
||||
throw new Error(
|
||||
"This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
|
||||
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
|
||||
);
|
||||
}
|
||||
|
||||
onProgress?.("Provisioning Cloud Code Assist project (this may take a moment)...");
|
||||
|
||||
// Build onboard request - for free tier, don't include project ID (Google provisions one)
|
||||
// For other tiers, include the user's project ID if available
|
||||
const onboardBody: Record<string, unknown> = {
|
||||
tierId,
|
||||
metadata: {
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
},
|
||||
};
|
||||
|
||||
if (tierId !== TIER_FREE && envProjectId) {
|
||||
onboardBody.cloudaicompanionProject = envProjectId;
|
||||
(onboardBody.metadata as Record<string, unknown>).duetProject = envProjectId;
|
||||
}
|
||||
|
||||
// Start onboarding - this returns a long-running operation
|
||||
const onboardResponse = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal:onboardUser`, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(onboardBody),
|
||||
});
|
||||
|
||||
if (!onboardResponse.ok) {
|
||||
const errorText = await onboardResponse.text();
|
||||
throw new Error(`onboardUser failed: ${onboardResponse.status} ${onboardResponse.statusText}: ${errorText}`);
|
||||
}
|
||||
|
||||
let lroData = (await onboardResponse.json()) as LongRunningOperationResponse;
|
||||
|
||||
// If the operation isn't done yet, poll until completion
|
||||
if (!lroData.done && lroData.name) {
|
||||
lroData = await pollOperation(lroData.name, headers, onProgress);
|
||||
}
|
||||
|
||||
// Try to get project ID from the response
|
||||
const projectId = lroData.response?.cloudaicompanionProject?.id;
|
||||
if (projectId) {
|
||||
return projectId;
|
||||
}
|
||||
|
||||
// If no project ID from onboarding, fall back to env var
|
||||
if (envProjectId) {
|
||||
return envProjectId;
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
"Could not discover or provision a Google Cloud project. " +
|
||||
"Try setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
|
||||
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get user email from the access token
|
||||
*/
|
||||
async function getUserEmail(accessToken: string): Promise<string | undefined> {
|
||||
try {
|
||||
const response = await fetch("https://www.googleapis.com/oauth2/v1/userinfo?alt=json", {
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = (await response.json()) as { email?: string };
|
||||
return data.email;
|
||||
}
|
||||
} catch {
|
||||
// Ignore errors, email is optional
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh Google Cloud Code Assist token
|
||||
*/
|
||||
export async function refreshGoogleCloudToken(refreshToken: string, projectId: string): Promise<OAuthCredentials> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
refresh_token: refreshToken,
|
||||
grant_type: "refresh_token",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`Google Cloud token refresh failed: ${error}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
return {
|
||||
refresh: data.refresh_token || refreshToken,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
projectId,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with Gemini CLI (Google Cloud Code Assist) OAuth
|
||||
*
|
||||
* @param onAuth - Callback with URL and optional instructions
|
||||
* @param onProgress - Optional progress callback
|
||||
* @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
|
||||
* Races with browser callback - whichever completes first wins.
|
||||
*/
|
||||
export async function loginGeminiCli(
|
||||
onAuth: (info: { url: string; instructions?: string }) => void,
|
||||
onProgress?: (message: string) => void,
|
||||
onManualCodeInput?: () => Promise<string>,
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
// Start local server for callback
|
||||
onProgress?.("Starting local server for OAuth callback...");
|
||||
const server = await startCallbackServer();
|
||||
|
||||
let code: string | undefined;
|
||||
|
||||
try {
|
||||
// Build authorization URL
|
||||
const authParams = new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
response_type: "code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
scope: SCOPES.join(" "),
|
||||
code_challenge: challenge,
|
||||
code_challenge_method: "S256",
|
||||
state: verifier,
|
||||
access_type: "offline",
|
||||
prompt: "consent",
|
||||
});
|
||||
|
||||
const authUrl = `${AUTH_URL}?${authParams.toString()}`;
|
||||
|
||||
// Notify caller with URL to open
|
||||
onAuth({
|
||||
url: authUrl,
|
||||
instructions: "Complete the sign-in in your browser.",
|
||||
});
|
||||
|
||||
// Wait for the callback, racing with manual input if provided
|
||||
onProgress?.("Waiting for OAuth callback...");
|
||||
|
||||
if (onManualCodeInput) {
|
||||
// Race between browser callback and manual input
|
||||
let manualInput: string | undefined;
|
||||
let manualError: Error | undefined;
|
||||
const manualPromise = onManualCodeInput()
|
||||
.then((input) => {
|
||||
manualInput = input;
|
||||
server.cancelWait();
|
||||
})
|
||||
.catch((err) => {
|
||||
manualError = err instanceof Error ? err : new Error(String(err));
|
||||
server.cancelWait();
|
||||
});
|
||||
|
||||
const result = await server.waitForCode();
|
||||
|
||||
// If manual input was cancelled, throw that error
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
|
||||
if (result?.code) {
|
||||
// Browser callback won - verify state
|
||||
if (result.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = result.code;
|
||||
} else if (manualInput) {
|
||||
// Manual input won
|
||||
const parsed = parseRedirectUrl(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
|
||||
// If still no code, wait for manual promise and try that
|
||||
if (!code) {
|
||||
await manualPromise;
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
if (manualInput) {
|
||||
const parsed = parseRedirectUrl(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Original flow: just wait for callback
|
||||
const result = await server.waitForCode();
|
||||
if (result?.code) {
|
||||
if (result.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = result.code;
|
||||
}
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
throw new Error("No authorization code received");
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
onProgress?.("Exchanging authorization code for tokens...");
|
||||
const tokenResponse = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
code,
|
||||
grant_type: "authorization_code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
code_verifier: verifier,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!tokenResponse.ok) {
|
||||
const error = await tokenResponse.text();
|
||||
throw new Error(`Token exchange failed: ${error}`);
|
||||
}
|
||||
|
||||
const tokenData = (await tokenResponse.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
if (!tokenData.refresh_token) {
|
||||
throw new Error("No refresh token received. Please try again.");
|
||||
}
|
||||
|
||||
// Get user email
|
||||
onProgress?.("Getting user info...");
|
||||
const email = await getUserEmail(tokenData.access_token);
|
||||
|
||||
// Discover project
|
||||
const projectId = await discoverProject(tokenData.access_token, onProgress);
|
||||
|
||||
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
|
||||
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
|
||||
|
||||
const credentials: OAuthCredentials = {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: expiresAt,
|
||||
projectId,
|
||||
email,
|
||||
};
|
||||
|
||||
return credentials;
|
||||
} finally {
|
||||
server.server.close();
|
||||
}
|
||||
}
|
||||
|
||||
export const geminiCliOAuthProvider: OAuthProviderInterface = {
|
||||
id: "google-gemini-cli",
|
||||
name: "Google Cloud Code Assist (Gemini CLI)",
|
||||
usesCallbackServer: true,
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const creds = credentials as GeminiCredentials;
|
||||
if (!creds.projectId) {
|
||||
throw new Error("Google Cloud credentials missing projectId");
|
||||
}
|
||||
return refreshGoogleCloudToken(creds.refresh, creds.projectId);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
const creds = credentials as GeminiCredentials;
|
||||
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
|
||||
},
|
||||
};
|
||||
162
packages/ai/src/utils/oauth/index.ts
Normal file
162
packages/ai/src/utils/oauth/index.ts
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
/**
|
||||
* OAuth credential management for AI providers.
|
||||
*
|
||||
* This module handles login, token refresh, and credential storage
|
||||
* for OAuth-based providers:
|
||||
* - Anthropic (Claude Pro/Max)
|
||||
* - GitHub Copilot
|
||||
* - Google Cloud Code Assist (Gemini CLI)
|
||||
* - Antigravity (Gemini 3, Claude, GPT-OSS via Google Cloud)
|
||||
*/
|
||||
|
||||
// Anthropic
|
||||
export { anthropicOAuthProvider, loginAnthropic, refreshAnthropicToken } from "./anthropic.js";
|
||||
// GitHub Copilot
|
||||
export {
|
||||
getGitHubCopilotBaseUrl,
|
||||
githubCopilotOAuthProvider,
|
||||
loginGitHubCopilot,
|
||||
normalizeDomain,
|
||||
refreshGitHubCopilotToken,
|
||||
} from "./github-copilot.js";
|
||||
// Google Antigravity
|
||||
export { antigravityOAuthProvider, loginAntigravity, refreshAntigravityToken } from "./google-antigravity.js";
|
||||
// Google Gemini CLI
|
||||
export { geminiCliOAuthProvider, loginGeminiCli, refreshGoogleCloudToken } from "./google-gemini-cli.js";
|
||||
// OpenAI Codex (ChatGPT OAuth)
|
||||
export { loginOpenAICodex, openaiCodexOAuthProvider, refreshOpenAICodexToken } from "./openai-codex.js";
|
||||
|
||||
export * from "./types.js";
|
||||
|
||||
// ============================================================================
|
||||
// Provider Registry
|
||||
// ============================================================================
|
||||
|
||||
import { anthropicOAuthProvider } from "./anthropic.js";
|
||||
import { githubCopilotOAuthProvider } from "./github-copilot.js";
|
||||
import { antigravityOAuthProvider } from "./google-antigravity.js";
|
||||
import { geminiCliOAuthProvider } from "./google-gemini-cli.js";
|
||||
import { openaiCodexOAuthProvider } from "./openai-codex.js";
|
||||
import type { OAuthCredentials, OAuthProviderId, OAuthProviderInfo, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
const BUILT_IN_OAUTH_PROVIDERS: OAuthProviderInterface[] = [
|
||||
anthropicOAuthProvider,
|
||||
githubCopilotOAuthProvider,
|
||||
geminiCliOAuthProvider,
|
||||
antigravityOAuthProvider,
|
||||
openaiCodexOAuthProvider,
|
||||
];
|
||||
|
||||
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>(
|
||||
BUILT_IN_OAUTH_PROVIDERS.map((provider) => [provider.id, provider]),
|
||||
);
|
||||
|
||||
/**
|
||||
* Get an OAuth provider by ID
|
||||
*/
|
||||
export function getOAuthProvider(id: OAuthProviderId): OAuthProviderInterface | undefined {
|
||||
return oauthProviderRegistry.get(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a custom OAuth provider
|
||||
*/
|
||||
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
|
||||
oauthProviderRegistry.set(provider.id, provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unregister an OAuth provider.
|
||||
*
|
||||
* If the provider is built-in, restores the built-in implementation.
|
||||
* Custom providers are removed completely.
|
||||
*/
|
||||
export function unregisterOAuthProvider(id: string): void {
|
||||
const builtInProvider = BUILT_IN_OAUTH_PROVIDERS.find((provider) => provider.id === id);
|
||||
if (builtInProvider) {
|
||||
oauthProviderRegistry.set(id, builtInProvider);
|
||||
return;
|
||||
}
|
||||
oauthProviderRegistry.delete(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset OAuth providers to built-ins.
|
||||
*/
|
||||
export function resetOAuthProviders(): void {
|
||||
oauthProviderRegistry.clear();
|
||||
for (const provider of BUILT_IN_OAUTH_PROVIDERS) {
|
||||
oauthProviderRegistry.set(provider.id, provider);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered OAuth providers
|
||||
*/
|
||||
export function getOAuthProviders(): OAuthProviderInterface[] {
|
||||
return Array.from(oauthProviderRegistry.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[]
|
||||
*/
|
||||
export function getOAuthProviderInfoList(): OAuthProviderInfo[] {
|
||||
return getOAuthProviders().map((p) => ({
|
||||
id: p.id,
|
||||
name: p.name,
|
||||
available: true,
|
||||
}));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// High-level API (uses provider registry)
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Refresh token for any OAuth provider.
|
||||
* @deprecated Use getOAuthProvider(id).refreshToken() instead
|
||||
*/
|
||||
export async function refreshOAuthToken(
|
||||
providerId: OAuthProviderId,
|
||||
credentials: OAuthCredentials,
|
||||
): Promise<OAuthCredentials> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
return provider.refreshToken(credentials);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider from OAuth credentials.
|
||||
* Automatically refreshes expired tokens.
|
||||
*
|
||||
* @returns API key string and updated credentials, or null if no credentials
|
||||
* @throws Error if refresh fails
|
||||
*/
|
||||
export async function getOAuthApiKey(
|
||||
providerId: OAuthProviderId,
|
||||
credentials: Record<string, OAuthCredentials>,
|
||||
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
|
||||
let creds = credentials[providerId];
|
||||
if (!creds) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Refresh if expired
|
||||
if (Date.now() >= creds.expires) {
|
||||
try {
|
||||
creds = await provider.refreshToken(creds);
|
||||
} catch (_error) {
|
||||
throw new Error(`Failed to refresh OAuth token for ${providerId}`);
|
||||
}
|
||||
}
|
||||
|
||||
const apiKey = provider.getApiKey(creds);
|
||||
return { newCredentials: creds, apiKey };
|
||||
}
|
||||
109
packages/ai/src/utils/oauth/oauth-page.ts
Normal file
109
packages/ai/src/utils/oauth/oauth-page.ts
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
const LOGO_SVG = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 800" aria-hidden="true"><path fill="#fff" fill-rule="evenodd" d="M165.29 165.29 H517.36 V400 H400 V517.36 H282.65 V634.72 H165.29 Z M282.65 282.65 V400 H400 V282.65 Z"/><path fill="#fff" d="M517.36 400 H634.72 V634.72 H517.36 Z"/></svg>`;
|
||||
|
||||
function escapeHtml(value: string): string {
|
||||
return value
|
||||
.replaceAll("&", "&")
|
||||
.replaceAll("<", "<")
|
||||
.replaceAll(">", ">")
|
||||
.replaceAll('"', """)
|
||||
.replaceAll("'", "'");
|
||||
}
|
||||
|
||||
function renderPage(options: { title: string; heading: string; message: string; details?: string }): string {
|
||||
const title = escapeHtml(options.title);
|
||||
const heading = escapeHtml(options.heading);
|
||||
const message = escapeHtml(options.message);
|
||||
const details = options.details ? escapeHtml(options.details) : undefined;
|
||||
|
||||
return `<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>${title}</title>
|
||||
<style>
|
||||
:root {
|
||||
--text: #fafafa;
|
||||
--text-dim: #a1a1aa;
|
||||
--page-bg: #09090b;
|
||||
--font-sans: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, "Noto Sans", sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji";
|
||||
--font-mono: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
|
||||
}
|
||||
* { box-sizing: border-box; }
|
||||
html { color-scheme: dark; }
|
||||
body {
|
||||
margin: 0;
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 24px;
|
||||
background: var(--page-bg);
|
||||
color: var(--text);
|
||||
font-family: var(--font-sans);
|
||||
text-align: center;
|
||||
}
|
||||
main {
|
||||
width: 100%;
|
||||
max-width: 560px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.logo {
|
||||
width: 72px;
|
||||
height: 72px;
|
||||
display: block;
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
font-size: 28px;
|
||||
line-height: 1.15;
|
||||
font-weight: 650;
|
||||
color: var(--text);
|
||||
}
|
||||
p {
|
||||
margin: 0;
|
||||
line-height: 1.7;
|
||||
color: var(--text-dim);
|
||||
font-size: 15px;
|
||||
}
|
||||
.details {
|
||||
margin-top: 16px;
|
||||
font-family: var(--font-mono);
|
||||
font-size: 13px;
|
||||
color: var(--text-dim);
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<div class="logo">${LOGO_SVG}</div>
|
||||
<h1>${heading}</h1>
|
||||
<p>${message}</p>
|
||||
${details ? `<div class="details">${details}</div>` : ""}
|
||||
</main>
|
||||
</body>
|
||||
</html>`;
|
||||
}
|
||||
|
||||
export function oauthSuccessHtml(message: string): string {
|
||||
return renderPage({
|
||||
title: "Authentication successful",
|
||||
heading: "Authentication successful",
|
||||
message,
|
||||
});
|
||||
}
|
||||
|
||||
export function oauthErrorHtml(message: string, details?: string): string {
|
||||
return renderPage({
|
||||
title: "Authentication failed",
|
||||
heading: "Authentication failed",
|
||||
message,
|
||||
details,
|
||||
});
|
||||
}
|
||||
450
packages/ai/src/utils/oauth/openai-codex.ts
Normal file
450
packages/ai/src/utils/oauth/openai-codex.ts
Normal file
|
|
@ -0,0 +1,450 @@
|
|||
/**
|
||||
* OpenAI Codex (ChatGPT OAuth) flow
|
||||
*
|
||||
* NOTE: This module uses Node.js crypto and http for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
|
||||
let _randomBytes: typeof import("node:crypto").randomBytes | null = null;
|
||||
let _http: typeof import("node:http") | null = null;
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
import("node:crypto").then((m) => {
|
||||
_randomBytes = m.randomBytes;
|
||||
});
|
||||
import("node:http").then((m) => {
|
||||
_http = m;
|
||||
});
|
||||
}
|
||||
|
||||
import { oauthErrorHtml, oauthSuccessHtml } from "./oauth-page.js";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
|
||||
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
|
||||
const TOKEN_URL = "https://auth.openai.com/oauth/token";
|
||||
const REDIRECT_URI = "http://localhost:1455/auth/callback";
|
||||
const SCOPE = "openid profile email offline_access";
|
||||
const JWT_CLAIM_PATH = "https://api.openai.com/auth";
|
||||
|
||||
type TokenSuccess = { type: "success"; access: string; refresh: string; expires: number };
|
||||
type TokenFailure = { type: "failed" };
|
||||
type TokenResult = TokenSuccess | TokenFailure;
|
||||
|
||||
type JwtPayload = {
|
||||
[JWT_CLAIM_PATH]?: {
|
||||
chatgpt_account_id?: string;
|
||||
};
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
function createState(): string {
|
||||
if (!_randomBytes) {
|
||||
throw new Error("OpenAI Codex OAuth is only available in Node.js environments");
|
||||
}
|
||||
return _randomBytes(16).toString("hex");
|
||||
}
|
||||
|
||||
function parseAuthorizationInput(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// not a URL
|
||||
}
|
||||
|
||||
if (value.includes("#")) {
|
||||
const [code, state] = value.split("#", 2);
|
||||
return { code, state };
|
||||
}
|
||||
|
||||
if (value.includes("code=")) {
|
||||
const params = new URLSearchParams(value);
|
||||
return {
|
||||
code: params.get("code") ?? undefined,
|
||||
state: params.get("state") ?? undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return { code: value };
|
||||
}
|
||||
|
||||
function decodeJwt(token: string): JwtPayload | null {
|
||||
try {
|
||||
const parts = token.split(".");
|
||||
if (parts.length !== 3) return null;
|
||||
const payload = parts[1] ?? "";
|
||||
const decoded = atob(payload);
|
||||
return JSON.parse(decoded) as JwtPayload;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function exchangeAuthorizationCode(
|
||||
code: string,
|
||||
verifier: string,
|
||||
redirectUri: string = REDIRECT_URI,
|
||||
): Promise<TokenResult> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "authorization_code",
|
||||
client_id: CLIENT_ID,
|
||||
code,
|
||||
code_verifier: verifier,
|
||||
redirect_uri: redirectUri,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
console.error("[openai-codex] code->token failed:", response.status, text);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
const json = (await response.json()) as {
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
expires_in?: number;
|
||||
};
|
||||
|
||||
if (!json.access_token || !json.refresh_token || typeof json.expires_in !== "number") {
|
||||
console.error("[openai-codex] token response missing fields:", json);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
access: json.access_token,
|
||||
refresh: json.refresh_token,
|
||||
expires: Date.now() + json.expires_in * 1000,
|
||||
};
|
||||
}
|
||||
|
||||
async function refreshAccessToken(refreshToken: string): Promise<TokenResult> {
|
||||
try {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "refresh_token",
|
||||
refresh_token: refreshToken,
|
||||
client_id: CLIENT_ID,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
console.error("[openai-codex] Token refresh failed:", response.status, text);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
const json = (await response.json()) as {
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
expires_in?: number;
|
||||
};
|
||||
|
||||
if (!json.access_token || !json.refresh_token || typeof json.expires_in !== "number") {
|
||||
console.error("[openai-codex] Token refresh response missing fields:", json);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
access: json.access_token,
|
||||
refresh: json.refresh_token,
|
||||
expires: Date.now() + json.expires_in * 1000,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("[openai-codex] Token refresh error:", error);
|
||||
return { type: "failed" };
|
||||
}
|
||||
}
|
||||
|
||||
async function createAuthorizationFlow(
|
||||
originator: string = "pi",
|
||||
): Promise<{ verifier: string; state: string; url: string }> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
const state = createState();
|
||||
|
||||
const url = new URL(AUTHORIZE_URL);
|
||||
url.searchParams.set("response_type", "code");
|
||||
url.searchParams.set("client_id", CLIENT_ID);
|
||||
url.searchParams.set("redirect_uri", REDIRECT_URI);
|
||||
url.searchParams.set("scope", SCOPE);
|
||||
url.searchParams.set("code_challenge", challenge);
|
||||
url.searchParams.set("code_challenge_method", "S256");
|
||||
url.searchParams.set("state", state);
|
||||
url.searchParams.set("id_token_add_organizations", "true");
|
||||
url.searchParams.set("codex_cli_simplified_flow", "true");
|
||||
url.searchParams.set("originator", originator);
|
||||
|
||||
return { verifier, state, url: url.toString() };
|
||||
}
|
||||
|
||||
type OAuthServerInfo = {
|
||||
close: () => void;
|
||||
cancelWait: () => void;
|
||||
waitForCode: () => Promise<{ code: string } | null>;
|
||||
};
|
||||
|
||||
function startLocalOAuthServer(state: string): Promise<OAuthServerInfo> {
|
||||
if (!_http) {
|
||||
throw new Error("OpenAI Codex OAuth is only available in Node.js environments");
|
||||
}
|
||||
|
||||
let settleWait: ((value: { code: string } | null) => void) | undefined;
|
||||
const waitForCodePromise = new Promise<{ code: string } | null>((resolve) => {
|
||||
let settled = false;
|
||||
settleWait = (value) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
resolve(value);
|
||||
};
|
||||
});
|
||||
|
||||
const server = _http.createServer((req, res) => {
|
||||
try {
|
||||
const url = new URL(req.url || "", "http://localhost");
|
||||
if (url.pathname !== "/auth/callback") {
|
||||
res.statusCode = 404;
|
||||
res.setHeader("Content-Type", "text/html; charset=utf-8");
|
||||
res.end(oauthErrorHtml("Callback route not found."));
|
||||
return;
|
||||
}
|
||||
if (url.searchParams.get("state") !== state) {
|
||||
res.statusCode = 400;
|
||||
res.setHeader("Content-Type", "text/html; charset=utf-8");
|
||||
res.end(oauthErrorHtml("State mismatch."));
|
||||
return;
|
||||
}
|
||||
const code = url.searchParams.get("code");
|
||||
if (!code) {
|
||||
res.statusCode = 400;
|
||||
res.setHeader("Content-Type", "text/html; charset=utf-8");
|
||||
res.end(oauthErrorHtml("Missing authorization code."));
|
||||
return;
|
||||
}
|
||||
res.statusCode = 200;
|
||||
res.setHeader("Content-Type", "text/html; charset=utf-8");
|
||||
res.end(oauthSuccessHtml("OpenAI authentication completed. You can close this window."));
|
||||
settleWait?.({ code });
|
||||
} catch {
|
||||
res.statusCode = 500;
|
||||
res.setHeader("Content-Type", "text/html; charset=utf-8");
|
||||
res.end(oauthErrorHtml("Internal error while processing OAuth callback."));
|
||||
}
|
||||
});
|
||||
|
||||
return new Promise((resolve) => {
|
||||
server
|
||||
.listen(1455, "127.0.0.1", () => {
|
||||
resolve({
|
||||
close: () => server.close(),
|
||||
cancelWait: () => {
|
||||
settleWait?.(null);
|
||||
},
|
||||
waitForCode: () => waitForCodePromise,
|
||||
});
|
||||
})
|
||||
.on("error", (err: NodeJS.ErrnoException) => {
|
||||
console.error(
|
||||
"[openai-codex] Failed to bind http://127.0.0.1:1455 (",
|
||||
err.code,
|
||||
") Falling back to manual paste.",
|
||||
);
|
||||
settleWait?.(null);
|
||||
resolve({
|
||||
close: () => {
|
||||
try {
|
||||
server.close();
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
},
|
||||
cancelWait: () => {},
|
||||
waitForCode: async () => null,
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function getAccountId(accessToken: string): string | null {
|
||||
const payload = decodeJwt(accessToken);
|
||||
const auth = payload?.[JWT_CLAIM_PATH];
|
||||
const accountId = auth?.chatgpt_account_id;
|
||||
return typeof accountId === "string" && accountId.length > 0 ? accountId : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with OpenAI Codex OAuth
|
||||
*
|
||||
* @param options.onAuth - Called with URL and instructions when auth starts
|
||||
* @param options.onPrompt - Called to prompt user for manual code paste (fallback if no onManualCodeInput)
|
||||
* @param options.onProgress - Optional progress messages
|
||||
* @param options.onManualCodeInput - Optional promise that resolves with user-pasted code.
|
||||
* Races with browser callback - whichever completes first wins.
|
||||
* Useful for showing paste input immediately alongside browser flow.
|
||||
* @param options.originator - OAuth originator parameter (defaults to "pi")
|
||||
*/
|
||||
export async function loginOpenAICodex(options: {
|
||||
onAuth: (info: { url: string; instructions?: string }) => void;
|
||||
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
onManualCodeInput?: () => Promise<string>;
|
||||
originator?: string;
|
||||
}): Promise<OAuthCredentials> {
|
||||
const { verifier, state, url } = await createAuthorizationFlow(options.originator);
|
||||
const server = await startLocalOAuthServer(state);
|
||||
|
||||
options.onAuth({ url, instructions: "A browser window should open. Complete login to finish." });
|
||||
|
||||
let code: string | undefined;
|
||||
try {
|
||||
if (options.onManualCodeInput) {
|
||||
// Race between browser callback and manual input
|
||||
let manualCode: string | undefined;
|
||||
let manualError: Error | undefined;
|
||||
const manualPromise = options
|
||||
.onManualCodeInput()
|
||||
.then((input) => {
|
||||
manualCode = input;
|
||||
server.cancelWait();
|
||||
})
|
||||
.catch((err) => {
|
||||
manualError = err instanceof Error ? err : new Error(String(err));
|
||||
server.cancelWait();
|
||||
});
|
||||
|
||||
const result = await server.waitForCode();
|
||||
|
||||
// If manual input was cancelled, throw that error
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
|
||||
if (result?.code) {
|
||||
// Browser callback won
|
||||
code = result.code;
|
||||
} else if (manualCode) {
|
||||
// Manual input won (or callback timed out and user had entered code)
|
||||
const parsed = parseAuthorizationInput(manualCode);
|
||||
if (parsed.state && parsed.state !== state) {
|
||||
throw new Error("State mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
|
||||
// If still no code, wait for manual promise to complete and try that
|
||||
if (!code) {
|
||||
await manualPromise;
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
if (manualCode) {
|
||||
const parsed = parseAuthorizationInput(manualCode);
|
||||
if (parsed.state && parsed.state !== state) {
|
||||
throw new Error("State mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Original flow: wait for callback, then prompt if needed
|
||||
const result = await server.waitForCode();
|
||||
if (result?.code) {
|
||||
code = result.code;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to onPrompt if still no code
|
||||
if (!code) {
|
||||
const input = await options.onPrompt({
|
||||
message: "Paste the authorization code (or full redirect URL):",
|
||||
});
|
||||
const parsed = parseAuthorizationInput(input);
|
||||
if (parsed.state && parsed.state !== state) {
|
||||
throw new Error("State mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
throw new Error("Missing authorization code");
|
||||
}
|
||||
|
||||
const tokenResult = await exchangeAuthorizationCode(code, verifier);
|
||||
if (tokenResult.type !== "success") {
|
||||
throw new Error("Token exchange failed");
|
||||
}
|
||||
|
||||
const accountId = getAccountId(tokenResult.access);
|
||||
if (!accountId) {
|
||||
throw new Error("Failed to extract accountId from token");
|
||||
}
|
||||
|
||||
return {
|
||||
access: tokenResult.access,
|
||||
refresh: tokenResult.refresh,
|
||||
expires: tokenResult.expires,
|
||||
accountId,
|
||||
};
|
||||
} finally {
|
||||
server.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh OpenAI Codex OAuth token
|
||||
*/
|
||||
export async function refreshOpenAICodexToken(refreshToken: string): Promise<OAuthCredentials> {
|
||||
const result = await refreshAccessToken(refreshToken);
|
||||
if (result.type !== "success") {
|
||||
throw new Error("Failed to refresh OpenAI Codex token");
|
||||
}
|
||||
|
||||
const accountId = getAccountId(result.access);
|
||||
if (!accountId) {
|
||||
throw new Error("Failed to extract accountId from token");
|
||||
}
|
||||
|
||||
return {
|
||||
access: result.access,
|
||||
refresh: result.refresh,
|
||||
expires: result.expires,
|
||||
accountId,
|
||||
};
|
||||
}
|
||||
|
||||
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
|
||||
id: "openai-codex",
|
||||
name: "ChatGPT Plus/Pro (Codex Subscription)",
|
||||
usesCallbackServer: true,
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginOpenAICodex({
|
||||
onAuth: callbacks.onAuth,
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
onManualCodeInput: callbacks.onManualCodeInput,
|
||||
});
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
return refreshOpenAICodexToken(credentials.refresh);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
return credentials.access;
|
||||
},
|
||||
};
|
||||
34
packages/ai/src/utils/oauth/pkce.ts
Normal file
34
packages/ai/src/utils/oauth/pkce.ts
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* PKCE utilities using Web Crypto API.
|
||||
* Works in both Node.js 20+ and browsers.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Encode bytes as base64url string.
|
||||
*/
|
||||
function base64urlEncode(bytes: Uint8Array): string {
|
||||
let binary = "";
|
||||
for (const byte of bytes) {
|
||||
binary += String.fromCharCode(byte);
|
||||
}
|
||||
return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge.
|
||||
* Uses Web Crypto API for cross-platform compatibility.
|
||||
*/
|
||||
export async function generatePKCE(): Promise<{ verifier: string; challenge: string }> {
|
||||
// Generate random verifier
|
||||
const verifierBytes = new Uint8Array(32);
|
||||
crypto.getRandomValues(verifierBytes);
|
||||
const verifier = base64urlEncode(verifierBytes);
|
||||
|
||||
// Compute SHA-256 challenge
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const hashBuffer = await crypto.subtle.digest("SHA-256", data);
|
||||
const challenge = base64urlEncode(new Uint8Array(hashBuffer));
|
||||
|
||||
return { verifier, challenge };
|
||||
}
|
||||
59
packages/ai/src/utils/oauth/types.ts
Normal file
59
packages/ai/src/utils/oauth/types.ts
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import type { Api, Model } from "../../types.js";
|
||||
|
||||
export type OAuthCredentials = {
|
||||
refresh: string;
|
||||
access: string;
|
||||
expires: number;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
export type OAuthProviderId = string;
|
||||
|
||||
/** @deprecated Use OAuthProviderId instead */
|
||||
export type OAuthProvider = OAuthProviderId;
|
||||
|
||||
export type OAuthPrompt = {
|
||||
message: string;
|
||||
placeholder?: string;
|
||||
allowEmpty?: boolean;
|
||||
};
|
||||
|
||||
export type OAuthAuthInfo = {
|
||||
url: string;
|
||||
instructions?: string;
|
||||
};
|
||||
|
||||
export interface OAuthLoginCallbacks {
|
||||
onAuth: (info: OAuthAuthInfo) => void;
|
||||
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
onManualCodeInput?: () => Promise<string>;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface OAuthProviderInterface {
|
||||
readonly id: OAuthProviderId;
|
||||
readonly name: string;
|
||||
|
||||
/** Run the login flow, return credentials to persist */
|
||||
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
|
||||
|
||||
/** Whether login uses a local callback server and supports manual code input. */
|
||||
usesCallbackServer?: boolean;
|
||||
|
||||
/** Refresh expired credentials, return updated credentials to persist */
|
||||
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
|
||||
|
||||
/** Convert credentials to API key string for the provider */
|
||||
getApiKey(credentials: OAuthCredentials): string;
|
||||
|
||||
/** Optional: modify models for this provider (e.g., update baseUrl) */
|
||||
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
|
||||
}
|
||||
|
||||
/** @deprecated Use OAuthProviderInterface instead */
|
||||
export interface OAuthProviderInfo {
|
||||
id: OAuthProviderId;
|
||||
name: string;
|
||||
available: boolean;
|
||||
}
|
||||
123
packages/ai/src/utils/overflow.ts
Normal file
123
packages/ai/src/utils/overflow.ts
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
import type { AssistantMessage } from "../types.js";
|
||||
|
||||
/**
|
||||
* Regex patterns to detect context overflow errors from different providers.
|
||||
*
|
||||
* These patterns match error messages returned when the input exceeds
|
||||
* the model's context window.
|
||||
*
|
||||
* Provider-specific patterns (with example error messages):
|
||||
*
|
||||
* - Anthropic: "prompt is too long: 213462 tokens > 200000 maximum"
|
||||
* - OpenAI: "Your input exceeds the context window of this model"
|
||||
* - Google: "The input token count (1196265) exceeds the maximum number of tokens allowed (1048575)"
|
||||
* - xAI: "This model's maximum prompt length is 131072 but the request contains 537812 tokens"
|
||||
* - Groq: "Please reduce the length of the messages or completion"
|
||||
* - OpenRouter: "This endpoint's maximum context length is X tokens. However, you requested about Y tokens"
|
||||
* - llama.cpp: "the request exceeds the available context size, try increasing it"
|
||||
* - LM Studio: "tokens to keep from the initial prompt is greater than the context length"
|
||||
* - GitHub Copilot: "prompt token count of X exceeds the limit of Y"
|
||||
* - MiniMax: "invalid params, context window exceeds limit"
|
||||
* - Kimi For Coding: "Your request exceeded model token limit: X (requested: Y)"
|
||||
* - Cerebras: Returns "400/413 status code (no body)" - handled separately below
|
||||
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
|
||||
* - z.ai: Does NOT error, accepts overflow silently - handled via usage.input > contextWindow
|
||||
* - Ollama: Silently truncates input - not detectable via error message
|
||||
*/
|
||||
const OVERFLOW_PATTERNS = [
|
||||
/prompt is too long/i, // Anthropic
|
||||
/input is too long for requested model/i, // Amazon Bedrock
|
||||
/exceeds the context window/i, // OpenAI (Completions & Responses API)
|
||||
/input token count.*exceeds the maximum/i, // Google (Gemini)
|
||||
/maximum prompt length is \d+/i, // xAI (Grok)
|
||||
/reduce the length of the messages/i, // Groq
|
||||
/maximum context length is \d+ tokens/i, // OpenRouter (all backends)
|
||||
/exceeds the limit of \d+/i, // GitHub Copilot
|
||||
/exceeds the available context size/i, // llama.cpp server
|
||||
/greater than the context length/i, // LM Studio
|
||||
/context window exceeds limit/i, // MiniMax
|
||||
/exceeded model token limit/i, // Kimi For Coding
|
||||
/too large for model with \d+ maximum context length/i, // Mistral
|
||||
/model_context_window_exceeded/i, // z.ai non-standard finish_reason surfaced as error text
|
||||
/context[_ ]length[_ ]exceeded/i, // Generic fallback
|
||||
/too many tokens/i, // Generic fallback
|
||||
/token limit exceeded/i, // Generic fallback
|
||||
];
|
||||
|
||||
/**
|
||||
* Check if an assistant message represents a context overflow error.
|
||||
*
|
||||
* This handles two cases:
|
||||
* 1. Error-based overflow: Most providers return stopReason "error" with a
|
||||
* specific error message pattern.
|
||||
* 2. Silent overflow: Some providers accept overflow requests and return
|
||||
* successfully. For these, we check if usage.input exceeds the context window.
|
||||
*
|
||||
* ## Reliability by Provider
|
||||
*
|
||||
* **Reliable detection (returns error with detectable message):**
|
||||
* - Anthropic: "prompt is too long: X tokens > Y maximum"
|
||||
* - OpenAI (Completions & Responses): "exceeds the context window"
|
||||
* - Google Gemini: "input token count exceeds the maximum"
|
||||
* - xAI (Grok): "maximum prompt length is X but request contains Y"
|
||||
* - Groq: "reduce the length of the messages"
|
||||
* - Cerebras: 400/413 status code (no body)
|
||||
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
|
||||
* - OpenRouter (all backends): "maximum context length is X tokens"
|
||||
* - llama.cpp: "exceeds the available context size"
|
||||
* - LM Studio: "greater than the context length"
|
||||
* - Kimi For Coding: "exceeded model token limit: X (requested: Y)"
|
||||
*
|
||||
* **Unreliable detection:**
|
||||
* - z.ai: Sometimes accepts overflow silently (detectable via usage.input > contextWindow),
|
||||
* sometimes returns rate limit errors. Pass contextWindow param to detect silent overflow.
|
||||
* - Ollama: Silently truncates input without error. Cannot be detected via this function.
|
||||
* The response will have usage.input < expected, but we don't know the expected value.
|
||||
*
|
||||
* ## Custom Providers
|
||||
*
|
||||
* If you've added custom models via settings.json, this function may not detect
|
||||
* overflow errors from those providers. To add support:
|
||||
*
|
||||
* 1. Send a request that exceeds the model's context window
|
||||
* 2. Check the errorMessage in the response
|
||||
* 3. Create a regex pattern that matches the error
|
||||
* 4. The pattern should be added to OVERFLOW_PATTERNS in this file, or
|
||||
* check the errorMessage yourself before calling this function
|
||||
*
|
||||
* @param message - The assistant message to check
|
||||
* @param contextWindow - Optional context window size for detecting silent overflow (z.ai)
|
||||
* @returns true if the message indicates a context overflow
|
||||
*/
|
||||
export function isContextOverflow(message: AssistantMessage, contextWindow?: number): boolean {
|
||||
// Case 1: Check error message patterns
|
||||
if (message.stopReason === "error" && message.errorMessage) {
|
||||
// Check known patterns
|
||||
if (OVERFLOW_PATTERNS.some((p) => p.test(message.errorMessage!))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Cerebras returns 400/413 with no body for context overflow
|
||||
// Note: 429 is rate limiting (requests/tokens per time), NOT context overflow
|
||||
if (/^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message.errorMessage)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Case 2: Silent overflow (z.ai style) - successful but usage exceeds context
|
||||
if (contextWindow && message.stopReason === "stop") {
|
||||
const inputTokens = message.usage.input + message.usage.cacheRead;
|
||||
if (inputTokens > contextWindow) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the overflow patterns for testing purposes.
|
||||
*/
|
||||
export function getOverflowPatterns(): RegExp[] {
|
||||
return [...OVERFLOW_PATTERNS];
|
||||
}
|
||||
25
packages/ai/src/utils/sanitize-unicode.ts
Normal file
25
packages/ai/src/utils/sanitize-unicode.ts
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Removes unpaired Unicode surrogate characters from a string.
|
||||
*
|
||||
* Unpaired surrogates (high surrogates 0xD800-0xDBFF without matching low surrogates 0xDC00-0xDFFF,
|
||||
* or vice versa) cause JSON serialization errors in many API providers.
|
||||
*
|
||||
* Valid emoji and other characters outside the Basic Multilingual Plane use properly paired
|
||||
* surrogates and will NOT be affected by this function.
|
||||
*
|
||||
* @param text - The text to sanitize
|
||||
* @returns The sanitized text with unpaired surrogates removed
|
||||
*
|
||||
* @example
|
||||
* // Valid emoji (properly paired surrogates) are preserved
|
||||
* sanitizeSurrogates("Hello 🙈 World") // => "Hello 🙈 World"
|
||||
*
|
||||
* // Unpaired high surrogate is removed
|
||||
* const unpaired = String.fromCharCode(0xD83D); // high surrogate without low
|
||||
* sanitizeSurrogates(`Text ${unpaired} here`) // => "Text here"
|
||||
*/
|
||||
export function sanitizeSurrogates(text: string): string {
|
||||
// Replace unpaired high surrogates (0xD800-0xDBFF not followed by low surrogate)
|
||||
// Replace unpaired low surrogates (0xDC00-0xDFFF not preceded by high surrogate)
|
||||
return text.replace(/[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g, "");
|
||||
}
|
||||
24
packages/ai/src/utils/typebox-helpers.ts
Normal file
24
packages/ai/src/utils/typebox-helpers.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import { type TUnsafe, Type } from "@sinclair/typebox";
|
||||
|
||||
/**
|
||||
* Creates a string enum schema compatible with Google's API and other providers
|
||||
* that don't support anyOf/const patterns.
|
||||
*
|
||||
* @example
|
||||
* const OperationSchema = StringEnum(["add", "subtract", "multiply", "divide"], {
|
||||
* description: "The operation to perform"
|
||||
* });
|
||||
*
|
||||
* type Operation = Static<typeof OperationSchema>; // "add" | "subtract" | "multiply" | "divide"
|
||||
*/
|
||||
export function StringEnum<T extends readonly string[]>(
|
||||
values: T,
|
||||
options?: { description?: string; default?: T[number] },
|
||||
): TUnsafe<T[number]> {
|
||||
return Type.Unsafe<T[number]>({
|
||||
type: "string",
|
||||
enum: values as any,
|
||||
...(options?.description && { description: options.description }),
|
||||
...(options?.default && { default: options.default }),
|
||||
});
|
||||
}
|
||||
93
packages/ai/src/utils/validation.ts
Normal file
93
packages/ai/src/utils/validation.ts
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import AjvModule from "ajv";
|
||||
import addFormatsModule from "ajv-formats";
|
||||
|
||||
// Handle both default and named exports
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
const addFormats = (addFormatsModule as any).default || addFormatsModule;
|
||||
|
||||
import type { Tool, ToolCall } from "../types.js";
|
||||
|
||||
// Detect if we're in a browser extension environment with strict CSP
|
||||
// Chrome extensions with Manifest V3 don't allow eval/Function constructor
|
||||
const isBrowserExtension = typeof globalThis !== "undefined" && (globalThis as any).chrome?.runtime?.id !== undefined;
|
||||
|
||||
function canUseRuntimeCodegen(): boolean {
|
||||
if (isBrowserExtension) {
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
new Function("return true;");
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a singleton AJV instance with formats only when runtime code generation is available.
|
||||
let ajv: any = null;
|
||||
if (canUseRuntimeCodegen()) {
|
||||
try {
|
||||
ajv = new Ajv({
|
||||
allErrors: true,
|
||||
strict: false,
|
||||
coerceTypes: true,
|
||||
});
|
||||
addFormats(ajv);
|
||||
} catch (_e) {
|
||||
console.warn("AJV validation disabled due to CSP restrictions");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a tool by name and validates the tool call arguments against its TypeBox schema
|
||||
* @param tools Array of tool definitions
|
||||
* @param toolCall The tool call from the LLM
|
||||
* @returns The validated arguments
|
||||
* @throws Error if tool is not found or validation fails
|
||||
*/
|
||||
export function validateToolCall(tools: Tool[], toolCall: ToolCall): any {
|
||||
const tool = tools.find((t) => t.name === toolCall.name);
|
||||
if (!tool) {
|
||||
throw new Error(`Tool "${toolCall.name}" not found`);
|
||||
}
|
||||
return validateToolArguments(tool, toolCall);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates tool call arguments against the tool's TypeBox schema
|
||||
* @param tool The tool definition with TypeBox schema
|
||||
* @param toolCall The tool call from the LLM
|
||||
* @returns The validated (and potentially coerced) arguments
|
||||
* @throws Error with formatted message if validation fails
|
||||
*/
|
||||
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
|
||||
// Skip validation in environments where runtime code generation is unavailable.
|
||||
if (!ajv || !canUseRuntimeCodegen()) {
|
||||
return toolCall.arguments;
|
||||
}
|
||||
|
||||
// Compile the schema.
|
||||
const validate = ajv.compile(tool.parameters);
|
||||
|
||||
// Clone arguments so AJV can safely mutate for type coercion
|
||||
const args = structuredClone(toolCall.arguments);
|
||||
|
||||
// Validate the arguments (AJV mutates args in-place for type coercion)
|
||||
if (validate(args)) {
|
||||
return args;
|
||||
}
|
||||
|
||||
// Format validation errors nicely
|
||||
const errors =
|
||||
validate.errors
|
||||
?.map((err: any) => {
|
||||
const path = err.instancePath ? err.instancePath.substring(1) : err.params.missingProperty || "root";
|
||||
return ` - ${path}: ${err.message}`;
|
||||
})
|
||||
.join("\n") || "Unknown validation error";
|
||||
|
||||
const errorMessage = `Validation failed for tool "${toolCall.name}":\n${errors}\n\nReceived arguments:\n${JSON.stringify(toolCall.arguments, null, 2)}`;
|
||||
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue