add /detect route for language detection

This commit is contained in:
Chris W 2023-10-30 20:41:16 -06:00
parent 4ee081df66
commit 4309b55aa4
2 changed files with 92 additions and 10 deletions

View File

@ -1,4 +1,5 @@
use anyhow::Error; use anyhow::Error;
use std::collections::HashMap;
use silicon::formatter::{ImageFormatter, ImageFormatterBuilder}; use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
use silicon::utils::{Background, ShadowAdder}; use silicon::utils::{Background, ShadowAdder};
use std::path::PathBuf; use std::path::PathBuf;
@ -11,6 +12,15 @@ use crate::rgba::{ImageRgba, Rgba};
type FontList = Vec<(String, f32)>; type FontList = Vec<(String, f32)>;
type Lines = Vec<u32>; type Lines = Vec<u32>;
// Evaluate `$e` (an expression yielding a `Result`); on `Ok(x)` the macro
// expands to `x`, on `Err(_)` it returns `$r` from the *enclosing function*.
// Used by request handlers to short-circuit into an error `HttpResponse`
// without the boilerplate of a `match` at every fallible call site.
macro_rules! unwrap_or_return {
    ( $e:expr, $r:expr ) => {
        match $e {
            Ok(x) => x,
            Err(_) => return $r,
        }
    };
}
#[derive(Debug, serde::Deserialize)] #[derive(Debug, serde::Deserialize)]
pub struct Config { pub struct Config {
/// Background image URL /// Background image URL
@ -132,15 +142,25 @@ impl Config {
ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| { ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
// Try using tensorflow to detect the language // Try using tensorflow to detect the language
let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap(); let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap();
self.predict_language_with_tensorflow(ps, input_data) let predictions = self.predict_language_with_tensorflow(ps, input_data).unwrap();
.unwrap_or_else(|_| ps.find_syntax_by_token("log").unwrap())
let mut max_score = -std::f32::INFINITY;
let mut max_language = "log";
for (language, score) in &predictions { // Borrow predictions here
if *score > max_score {
max_score = *score;
max_language = language;
}
}
ps.find_syntax_by_token(max_language).unwrap_or_else(|| ps.find_syntax_by_token("log").unwrap())
}) })
}, },
}; };
Ok(language) Ok(language)
} }
pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<&'a SyntaxReference, Error> { pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<HashMap<String, f32>, Error> {
if self.tf_model_graph.is_none() || self.tf_model.is_none() { if self.tf_model_graph.is_none() || self.tf_model.is_none() {
return Err(Error::msg("TensorFlow model not loaded")); return Err(Error::msg("TensorFlow model not loaded"));
} }
@ -165,14 +185,14 @@ impl Config {
let classes: Tensor<String> = args.fetch(output_token_classes)?; let classes: Tensor<String> = args.fetch(output_token_classes)?;
// Find the index of the highest score let mut result: HashMap<String, f32> = HashMap::new();
let max_index = scores.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap().0; for (i, score) in scores.iter().enumerate() {
let class = classes[i].clone();
let log_score = score.log2();
result.insert(class, log_score);
}
let language = &classes[max_index]; Ok(result)
let language = ps.find_syntax_by_token(language).unwrap();
Ok(language)
} }

View File

@ -7,6 +7,7 @@ use anyhow::Error;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use silicon as si; use silicon as si;
use silicon::utils::ToRgba; use silicon::utils::ToRgba;
use tensorflow::Tensor;
use std::collections::HashSet; use std::collections::HashSet;
use std::io::Cursor; use std::io::Cursor;
use std::num::ParseIntError; use std::num::ParseIntError;
@ -20,6 +21,7 @@ lazy_static! {
static ref HIGHLIGHTING_ASSETS: si::assets::HighlightingAssets = static ref HIGHLIGHTING_ASSETS: si::assets::HighlightingAssets =
silicon::assets::HighlightingAssets::new(); silicon::assets::HighlightingAssets::new();
} }
macro_rules! unwrap_or_return { macro_rules! unwrap_or_return {
( $e:expr, $r:expr ) => { ( $e:expr, $r:expr ) => {
match $e { match $e {
@ -87,6 +89,12 @@ async fn help() -> impl Responder {
"GET /themes": "Return a list of available syntax themes.", "GET /themes": "Return a list of available syntax themes.",
"GET /languages": "Retuns a list of languages which can be parsed.", "GET /languages": "Retuns a list of languages which can be parsed.",
"GET /fonts": "Returns a list of available fonts.", "GET /fonts": "Returns a list of available fonts.",
"GET /detect": {
"description": "Detect the language of the given code.",
"parameters": {
"code": "The code to detect the language of. Required."
}
},
"GET /generate": { "GET /generate": {
"description": "Generate an image from the given code.", "description": "Generate an image from the given code.",
"parameters": { "parameters": {
@ -150,6 +158,59 @@ async fn fonts() -> impl Responder {
HttpResponse::Ok().json(fonts) HttpResponse::Ok().json(fonts)
} }
/// GET /detect — predict the programming language of the `code` query
/// parameter using the TensorFlow classifier.
///
/// Returns a JSON array of `{"language": ..., "score": ...}` objects with
/// scores normalized to 0–100 and sorted descending. Responds with 400 when
/// `code` is missing or the model fails to produce predictions.
#[get("/detect")]
async fn detect(info: web::Query<config::ConfigQuery>) -> impl Responder {
    let args = CliArgs::parse();
    let ha = &*HIGHLIGHTING_ASSETS;
    let (ps, _ts) = (&ha.syntax_set, &ha.theme_set);

    let mut conf = config::Config::default();
    conf.code = info.code.clone();
    if conf.code.is_empty() {
        return HttpResponse::BadRequest()
            .append_header(("Content-Type", "application/json"))
            .body(r#"{"error": "code parameter is required"}"#);
    }

    if let Some(dir) = args.tensorflow_model_dir.as_deref() {
        conf.load_tensorflow_model(dir);
    }

    let input_data = Tensor::new(&[1]).with_values(&[conf.code.clone()]).unwrap();
    let predictions = unwrap_or_return!(
        conf.predict_language_with_tensorflow(ps, input_data),
        HttpResponse::BadRequest()
            .append_header(("Content-Type", "application/json"))
            .body(r#"{"error": "Failed to detect language."}"#)
    );

    // Normalize raw scores to 0–100. Guard against a zero range (all scores
    // equal, e.g. a single prediction): the naive formula would divide by
    // zero, produce NaN, and then panic in the `partial_cmp().unwrap()` sort.
    let min_score = predictions.values().copied().fold(f32::INFINITY, f32::min);
    let max_score = predictions.values().copied().fold(f32::NEG_INFINITY, f32::max);
    let range = max_score - min_score;

    let mut normalized: Vec<(&String, f32)> = predictions
        .iter()
        .map(|(lang, score)| {
            let pct = if range == 0.0 {
                100.0
            } else {
                (score - min_score) / range * 100.0
            };
            (lang, pct)
        })
        .collect();
    normalized.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // NOTE(review): language tokens come from the model's class list and are
    // assumed to contain no characters needing JSON escaping — confirm if the
    // class vocabulary can ever include quotes or backslashes.
    let response = normalized
        .iter()
        .map(|(language, score)| format!("{{\"language\": \"{}\", \"score\": {}}}", language, score))
        .collect::<Vec<_>>()
        .join(",");

    HttpResponse::Ok()
        .append_header(("Content-Type", "application/json"))
        .body(format!("[{}]", response))
}
#[get("/generate")] #[get("/generate")]
async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder { async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
let args = CliArgs::parse(); let args = CliArgs::parse();
@ -290,6 +351,7 @@ async fn main() -> std::io::Result<()> {
.service(themes) .service(themes)
.service(languages) .service(languages)
.service(fonts) .service(fonts)
.service(detect)
.service(generate) .service(generate)
}) })
.bind((host.clone(), port.parse::<u16>().unwrap()))? .bind((host.clone(), port.parse::<u16>().unwrap()))?